#!/usr/bin/env python

# pylint: disable=bad-whitespace
# pylint: disable=multiple-statements
# pylint: disable=len-as-condition
# pylint: disable=bare-except
# pylint: disable=redefined-outer-name
# pylint: disable=unidiomatic-typecheck
# pylint: disable=too-many-return-statements
# pylint: disable=wildcard-import
# pylint: disable=unused-wildcard-import

# TODO:
# ! add option for padding
# - fix occasionally missing page numbers
# - treat large h-whitespace as separator
# - handle overlapping candidates
# - use cc distance statistics instead of character scale
# - page frame detection
# - read and use text image segmentation mask
# - pick up stragglers
# ? laplacian as well

from __future__ import print_function

import argparse
import glob
import os
import os.path
import sys
import traceback
from multiprocessing import Pool

import numpy as np
# sub-namespaces have been deprecated, but we like to use them anyway
from scipy import ndimage as measurements
from scipy.ndimage import gaussian_filter,uniform_filter,maximum_filter
from imageio import imwrite

import ocrolib
from ocrolib import psegutils,morph,sl
from ocrolib.exceptions import OcropusException
from ocrolib.toplevel import *

# Command-line interface.  The automatic -h/--help option is disabled here
# (add_help=False) because it is registered by hand in the 'others' group
# below.  Note that the arguments are parsed at module level, i.e. as soon
# as this file is executed.
parser = argparse.ArgumentParser(add_help=False)

# error checking
group_error_checking = parser.add_argument_group('error checking')
group_error_checking.add_argument('-n','--nocheck',action="store_true",
                    help="disable error checking on inputs")
# limits
group_error_checking.add_argument('--minscale',type=float,default=12.0,
                    help='minimum scale permitted, default: %(default)s')
# NOTE(review): --maxlines is parsed as a float although it bounds a line count.
group_error_checking.add_argument('--maxlines',type=float,default=300,
                    help='maximum # lines permitted, default: %(default)s')

# scale parameters
group_scale = parser.add_argument_group('scale parameters')
group_scale.add_argument('--scale',type=float,default=0.0,
                    help='the basic scale of the document (roughly, xheight) 0=automatic, default: %(default)s')
group_scale.add_argument('--hscale',type=float,default=1.0,
                    help='non-standard scaling of horizontal parameters, default: %(default)s')
group_scale.add_argument('--vscale',type=float,default=1.0,
                    help='non-standard scaling of vertical parameters, default: %(default)s')

# line parameters
group_line = parser.add_argument_group('line parameters')
group_line.add_argument('--threshold',type=float,default=0.2,
                    help='baseline threshold, default: %(default)s')
group_line.add_argument('--noise',type=int,default=8,
                    help="noise threshold for removing small components from lines, default: %(default)s")
group_line.add_argument('--usegauss',action='store_true',
                    help='use gaussian instead of uniform, default: %(default)s')

# column parameters
group_column = parser.add_argument_group('column parameters')
group_column.add_argument('--maxseps',type=int,default=0,
                    help='maximum black column separators, default: %(default)s')
group_column.add_argument('--sepwiden',type=int,default=10,
                    help='widen black separators (to account for warping), default: %(default)s')
# Obsolete parameter for 'also check for black column separators'
# which can now be triggered simply by a positive maxseps value.
# It is hidden from the help text; compute_colseps still honours it by
# setting maxseps to 2 when maxseps was left at zero.
group_column.add_argument('-b','--blackseps',action="store_true",
                    help=argparse.SUPPRESS)

# whitespace column separators
group_column.add_argument('--maxcolseps',type=int,default=3,
                    help='maximum # whitespace column separators, default: %(default)s')
group_column.add_argument('--csminheight',type=float,default=10,
                    help='minimum column height (units=scale), default: %(default)s')
# Obsolete parameter for the 'minimum aspect ratio for column separators'
# used in the obsolete function compute_colseps_morph
# (hidden from the help text as well).
group_column.add_argument('--csminaspect',type=float,default=1.1,
                    help=argparse.SUPPRESS)

# output parameters
group_output = parser.add_argument_group('output parameters')
group_output.add_argument('--gray',action='store_true',
                    help='output grayscale lines as well, which are extracted from the grayscale version of the pages, default: %(default)s')
group_output.add_argument('-p','--pad',type=int,default=3,
                    help='padding for extracted lines, default: %(default)s')
group_output.add_argument('-e','--expand',type=int,default=3,
                    help='expand mask for grayscale extraction, default: %(default)s')

# other parameters
group_others = parser.add_argument_group('others')
group_others.add_argument('-q','--quiet',action='store_true',
                    help='be less verbose, default: %(default)s')
group_others.add_argument('-Q','--parallel',type=int,default=0,
                    help="number of CPUs to use")
group_others.add_argument('-d','--debug',action="store_true")
# manual replacement for the help option disabled via add_help=False
group_others.add_argument("-h", "--help", action="help", help="show this help message and exit")

# input files
parser.add_argument('files',nargs='+')

# Parse the command line right away; the resulting `args` namespace is read
# as a global by the functions below.
args = parser.parse_args()
# presumably expands glob patterns in the file arguments -- see ocrolib.glob_all
args.files = ocrolib.glob_all(args.files)

def find(condition):
    """Return the indices at which the flattened `condition` is true."""
    return np.flatnonzero(condition)
    
def norm_max(v):
    """Divide `v` by its largest element."""
    peak = np.amax(v)
    return v / peak


def check_page(image):
    """Run basic sanity checks on a candidate page image.

    The checks are applied in order: the image must be two-dimensional
    (not color), its mean must not be below its median (otherwise it is
    reported as possibly inverted), height and width must both lie in
    the range 600..10000, and the number of connected components of
    ``image > mean(image)`` must be at least 10 and at most one per
    30x30 pixel cell.

    Returns None when all checks pass, otherwise a string describing
    the first check that failed.
    """
    if len(image.shape) == 3:
        return "input image is color image %s" % (image.shape,)
    if np.mean(image) < np.median(image):
        return "image may be inverted"
    h, w = image.shape
    if h < 600:
        return "image not tall enough for a page image %s" % (image.shape,)
    if h > 10000:
        return "image too tall for a page image %s" % (image.shape,)
    if w < 600:
        return "image too narrow for a page image %s" % (image.shape,)
    if w > 10000:
        # this is a page check, so report the image (not a "line") as too wide
        return "image too wide for a page image %s" % (image.shape,)
    # upper bound on components: one per 30x30 pixel cell of the page
    slots = int(w*h*1.0/(30*30))
    _, ncomps = measurements.label(image > np.mean(image))
    if ncomps < 10:
        return "too few connected components for a page image (got %d)" % (ncomps,)
    if ncomps > slots:
        return "too many connected components for a page image (%d > %d)" % (ncomps, slots)
    return None


def print_info(*objs):
    """Print `objs` on standard output, preceded by an ``INFO: `` tag."""
    tagged = ("INFO: ",) + objs
    print(*tagged, file=sys.stdout)


def print_error(*objs):
    """Print `objs` on standard error, preceded by an ``ERROR: `` tag."""
    tagged = ("ERROR: ",) + objs
    print(*tagged, file=sys.stderr)


# With no input files there is nothing to do: show the usage text and leave.
if not len(args.files):
    parser.print_help()
    sys.exit(0)

# Banner: the command line (cut to 60 characters) framed by empty lines.
print_info("")
print_info("#" * 10, " ".join(sys.argv)[:60])
print_info("")

# Running with more than one worker forces quiet mode.
if args.parallel > 1:
    args.quiet = 1


def B(a):
    """Return `a` as an unsigned-byte array.

    The argument itself is handed back when it already has that dtype;
    otherwise a converted copy is made.
    """
    already_bytes = (a.dtype == np.dtype('B'))
    return a if already_bytes else np.array(a, 'B')


def DSAVE(title, image):
    """Write a debugging image to ``_<title>.png``; a no-op without --debug.

    When `image` is a list it must hold exactly three arrays, which are
    stacked and transposed so that the list index becomes the last axis.
    """
    if args.debug:
        if type(image) == list:
            assert len(image) == 3
            stacked = np.array(image)
            image = np.transpose(stacked, [1, 2, 0])
        fname = "_" + title + ".png"
        print_info("debug " + fname)
        imwrite(fname, image.astype('float'))


################################################################
### Column finding.
###
### This attempts to find column separators, either as extended
### vertical black lines or extended vertical whitespace.
### It will work fairly well in simple cases, but for unusual
### documents, you need to tune the parameters or use a mask.
################################################################

def compute_separators_morph(binary,scale):
    """Locate vertical black rules that act as column separators.

    The image is dilated, opened with a tall (10*scale by 1) element,
    eroded again, and finally filtered twice with select_regions:
    first on sl.dim1, then on sl.dim0, keeping at most args.maxseps
    regions in the end.
    """
    widen = args.sepwiden
    d0 = int(max(5, scale/4))
    d1 = int(max(5, scale)) + widen
    result = morph.r_dilation(binary, (d0, d1))
    result = morph.rb_opening(result, (10*scale, 1))
    result = morph.r_erosion(result, (d0//2, widen))
    result = morph.select_regions(result, sl.dim1, min=3, nbest=2*args.maxseps)
    return morph.select_regions(result, sl.dim0, min=20*scale, nbest=args.maxseps)


def compute_colseps_morph(binary,scale,maxseps=3,minheight=20,maxwidth=5):
    """Find tall whitespace column separators with morphology only.

    The keyword parameters are not read by the body; the limits that
    are actually applied come from args.csminaspect, args.csminheight
    and args.maxcolseps.
    """
    boxmap = psegutils.compute_boxmap(binary, scale, dtype='B')
    side = int(5*scale)
    closed = morph.rb_closing(B(boxmap), (side, side))
    # `bounds` is computed as in the historical code but never used below
    bounds = np.maximum(B(1-closed), B(boxmap))
    candidates = 1 - morph.rb_closing(boxmap, (int(20*scale), int(scale)))
    candidates = morph.select_regions(candidates, sl.aspect, min=args.csminaspect)
    candidates = morph.select_regions(candidates, sl.dim0,
                                      min=args.csminheight*scale,
                                      nbest=args.maxcolseps)
    reach = int(0.5+scale)
    candidates = morph.r_erosion(candidates, (reach, 0))
    return morph.r_dilation(candidates, (reach, 0), origin=(int(scale/2)-1, 0))


def compute_colseps_mconv(binary,scale=1.0):
    """Find column separators by mixing convolution and morphology.

    Low-response areas of the smoothed image are intersected with the
    morphologically closed image, the tallest resulting regions are
    kept, and everything outside the (dilated) closed blocks is added.
    Intermediate images go through DSAVE.
    """
    # unpacking only verifies that the image is two-dimensional
    _height, _width = binary.shape
    blurred = gaussian_filter(1.0*binary, (scale, scale*0.5))
    blurred = uniform_filter(blurred, (5.0*scale, 1))
    whitespace = (blurred < np.amax(blurred)*0.1)
    DSAVE("1thresh", whitespace)
    side = int(4*scale)
    blocks = morph.rb_closing(binary, (side, side))
    DSAVE("2blocks", blocks)
    candidates = morph.select_regions(np.minimum(blocks, whitespace), sl.dim0,
                                      min=args.csminheight*scale,
                                      nbest=args.maxcolseps)
    DSAVE("3seps", candidates)
    blocks = morph.r_dilation(blocks, (5, 5))
    DSAVE("4blocks", blocks)
    combined = np.maximum(candidates, 1-blocks)
    DSAVE("5combo", combined)
    return combined


def compute_colseps_conv(binary,scale=1.0):
    """Find whitespace column separators by filtering and thresholding.

    Vertical whitespace (low smoothed response) is intersected with
    strong positive horizontal-derivative responses; the tallest
    regions, limited by args.csminheight and args.maxcolseps, are
    returned.  Intermediate images go through DSAVE.
    """
    # unpacking only verifies that the image is two-dimensional
    _height, _width = binary.shape
    as_float = 1.0*binary
    sigma = (scale, scale*0.5)

    # vertical whitespace: below 10% of the peak of the smoothed image
    blurred = uniform_filter(gaussian_filter(as_float, sigma), (5.0*scale, 1))
    whitespace = (blurred < np.amax(blurred)*0.1)
    DSAVE("1thresh", whitespace)

    # column edges: first derivative along axis 1, averaged along axis 0
    edges = gaussian_filter(1.0*binary, sigma, order=(0, 1))
    edges = uniform_filter(edges, (10.0*scale, 1))
    # edges = abs(edges) # use this for finding both edges
    edges = (edges > 0.5*np.amax(edges))
    DSAVE("2grad", edges)

    # whitespace that lies close to an edge, grown vertically
    near_edge = maximum_filter(edges, (int(scale), int(5*scale)))
    candidates = np.minimum(whitespace, near_edge)
    candidates = maximum_filter(candidates, (int(2*scale), 1))
    DSAVE("3seps", candidates)

    # keep only the biggest separators
    candidates = morph.select_regions(candidates, sl.dim0,
                                      min=args.csminheight*scale,
                                      nbest=args.maxcolseps)
    DSAVE("4seps", candidates)
    return candidates


def compute_colseps(binary,scale):
    """Compute column separators from whitespace and, optionally, black rules.

    Whitespace separators are always computed.  Black separators are
    added when args.maxseps is positive; detected rules are also erased
    from `binary`.  Finally apply_mask is applied to both images.

    Returns the pair ``(colseps, binary)``.
    """
    print_info("considering at most %g whitespace column separators" % args.maxcolseps)
    colseps = compute_colseps_conv(binary, scale)
    DSAVE("colwsseps", 0.7*colseps+0.3*binary)

    # The obsolete --blackseps flag emulates the former default of two
    # black separators, but only while maxseps is still at zero (i.e.
    # has not been set to a non-zero value by hand).
    legacy_blackseps = args.blackseps and args.maxseps == 0
    if legacy_blackseps:
        args.maxseps = 2

    if args.maxseps > 0:
        print_info("considering at most %g black column separators" % args.maxseps)
        black = compute_separators_morph(binary, scale)
        DSAVE("colseps", 0.7*black+0.3*binary)
        #colseps = compute_colseps_morph(binary,scale)
        colseps = np.maximum(colseps, black)
        binary = np.minimum(binary, 1-black)

    masked_binary, masked_seps = apply_mask(binary, colseps)
    return masked_seps, masked_binary

def apply_mask(binary,colseps):
    """Merge an optional ``<base>.mask.png`` image into the separators.

    `base` is the module-level name set by process1.  If reading the
    mask raises IOError, both inputs are returned untouched; otherwise
    the mask is added to `colseps` and removed from `binary`.

    Returns the pair ``(binary, colseps)``.
    """
    try:
        mask = ocrolib.read_image_binary(base+".mask.png")
    except IOError:
        outcome = (binary, colseps)
    else:
        combined = np.maximum(colseps, mask)
        remaining = np.minimum(binary, 1-combined)
        DSAVE("masked_seps", combined)
        outcome = (remaining, combined)
    return outcome


################################################################
### Text Line Finding.
###
### This identifies the tops and bottoms of text lines by
### computing gradients and performing some adaptive thresholding.
### Those components are then used as seeds for the text lines.
################################################################

def compute_gradmaps(binary,scale):
    """Compute gradient maps used to locate the edges of text lines.

    The image is restricted to the box map from
    psegutils.compute_boxmap and filtered with a derivative along
    axis 0.  The negative and the positive part of that response are
    each passed through ocrolib.norm_max.

    Returns ``(bottom, top, boxmap)``.
    """
    boxmap = psegutils.compute_boxmap(binary, scale)
    cleaned = boxmap*binary
    DSAVE("cleaned", cleaned)
    as_float = 1.0*cleaned
    if not args.usegauss:
        # non-Gaussian oriented filters: a narrower Gaussian derivative
        # followed by a wide box filter
        sigma = (max(4, args.vscale*0.3*scale), args.hscale*scale)
        response = gaussian_filter(as_float, sigma, order=(1, 0))
        response = uniform_filter(response, (args.vscale, args.hscale*6*scale))
    else:
        # a single, horizontally wide Gaussian derivative
        sigma = (args.vscale*0.3*scale, args.hscale*6*scale)
        response = gaussian_filter(as_float, sigma, order=(1, 0))
    falling = (response < 0)*(-response)
    rising = (response > 0)*response
    bottom = ocrolib.norm_max(falling)
    top = ocrolib.norm_max(rising)
    return bottom, top, boxmap


def compute_line_seeds(binary,bottom,top,colseps,scale):
    """Based on gradient maps, computes candidates for baselines
    and xheights.  Then, it marks the regions between the two
    as a line seed.

    `bottom` and `top` are the gradient maps, `colseps` is the column
    separator image (seeds are suppressed where it is set) and `scale`
    is the document scale.  Returns the labeled seed image, i.e. the
    first element of morph.label applied to the seed mask."""
    t = args.threshold
    vrange = int(args.vscale*scale)
    # mark pixels where `bottom` equals its maximum_filter response
    # (size tuple (vrange,0)), grow the marks by a 2x2 window ...
    bmarked = maximum_filter(bottom==maximum_filter(bottom,(vrange,0)),(2,2))
    # ... and keep only marks above t*t*max(bottom) outside column separators
    bmarked = bmarked*(bottom>t*np.amax(bottom)*t)*(1-colseps)
    # same for `top`, but with half the threshold (t*t/2*max(top))
    tmarked = maximum_filter(top==maximum_filter(top,(vrange,0)),(2,2))
    tmarked = tmarked*(top>t*np.amax(top)*t/2)*(1-colseps)
    # widen the top marks along axis 1 with a window of 20
    tmarked = maximum_filter(tmarked,(1,20))
    seeds = np.zeros(binary.shape,'i')
    delta = max(3,int(scale/2))
    # scan every column separately
    for x in range(bmarked.shape[1]):
        # (y,1) entries come from bottom marks, (y,0) entries from top marks;
        # the list is ordered by decreasing y
        transitions = sorted([(y,1) for y in find(bmarked[:,x])]+[(y,0) for y in find(tmarked[:,x])])[::-1]
        # sentinel: row 0 acts like a final top mark
        transitions += [(0,0)]
        for l in range(len(transitions)-1):
            y0,s0 = transitions[l]
            # only bottom marks start a seed
            if s0==0: continue
            # always mark the delta rows preceding the bottom mark
            # NOTE(review): for y0<delta the slice start is negative and
            # therefore counts from the end of the column -- confirm intended
            seeds[y0-delta:y0,x] = 1
            y1,s1 = transitions[l+1]
            # if the next entry is a top mark (or the sentinel) less than
            # 5*scale away, fill the whole span between the two marks
            if s1==0 and (y0-y1)<5*scale: seeds[y1:y0,x] = 1
    # grow the seeds along axis 1 and cut them at the column separators
    seeds = maximum_filter(seeds,(1,int(1+scale)))
    seeds = seeds*(1-colseps)
    DSAVE("lineseeds",[seeds,0.3*tmarked+0.7*bmarked,binary])
    # give every connected seed region its own label
    seeds,_ = morph.label(seeds)
    return seeds


################################################################
### The complete line segmentation process.
################################################################

def remove_hlines(binary,scale,maxsize=10):
    """Erase connected components wider than ``maxsize*scale``.

    Returns an unsigned-byte image that is nonzero exactly where a
    remaining component is present.
    """
    labels, _ = morph.label(binary)
    limit = maxsize*scale
    # component labels are 1-based, matching the position in find_objects
    for label, box in enumerate(morph.find_objects(labels), 1):
        if sl.width(box) > limit:
            labels[box][labels[box] == label] = 0
    return np.array(labels != 0, 'B')


def compute_segmentation(binary,scale):
    """Segment a binary page image into labeled text lines.

    The steps are: remove wide horizontal components, find column
    separators, compute line seeds from the gradient maps, and then
    propagate/spread the seed labels over the remaining foreground.
    Returns the label image restricted to the foreground of `binary`.
    """
    def report(message):
        # progress messages are suppressed in quiet mode
        if not args.quiet: print_info(message)

    page = np.array(binary, 'B')

    # horizontal black lines only interfere with the rest of the
    # page segmentation, so they are removed first
    page = remove_hlines(page, scale)

    # column finding
    report("computing column separators")
    colseps, page = compute_colseps(page, scale)

    # text line seeds
    report("computing lines")
    bottom, top, boxmap = compute_gradmaps(page, scale)
    seeds = compute_line_seeds(page, bottom, top, colseps, scale)
    DSAVE("seeds", [bottom, top, boxmap])

    # hand the seed labels to all the remaining components
    report("propagating labels")
    propagated = morph.propagate_labels(boxmap, seeds, conflict=0)
    report("spreading labels")
    spread = morph.spread_labels(seeds, maxdist=scale)
    # where propagation produced no label, fall back to the spread labels
    merged = np.where(propagated > 0, propagated, spread*page)
    return merged*page


################################################################
### Processing each file.
################################################################

def process1(job):
    """Segment one page into text lines and write the results to disk.

    `job` is a ``(fname, jobno)`` pair.  The binary page is read from
    ``<base>.bin.png`` (falling back to `fname` itself), where `base`
    is `fname` stripped of its extensions by ocrolib.allsplitext.  On
    success the function writes ``<base>.pseg.png`` plus one
    ``<base>/01XXXX.bin.png`` per line (and ``01XXXX.nrm.png`` with
    --gray), and prints a summary line with the job number.

    Problems with the input (unreadable file, failed page check,
    missing grayscale image, bad scale, too many lines) are reported
    through print_error and end the function early.  Always returns None.
    """
    # `base` is shared with apply_mask through a module-level global
    global base
    fname, jobno = job
    base, _ = ocrolib.allsplitext(fname)
    outputdir = base

    try:
        binary = ocrolib.read_image_binary(base+".bin.png")
    except IOError:
        try:
            binary = ocrolib.read_image_binary(fname)
        except IOError:
            if ocrolib.trace: traceback.print_exc()
            print_error("cannot open either %s.bin.png or %s" % (base, fname))
            return

    checktype(binary, ABINARY2)

    if not args.nocheck:
        # check_page is run on the inverted image
        check = check_page(np.amax(binary)-binary)
        if check is not None:
            print_error("%s SKIPPED %s (use -n to disable this check)" % (fname, check))
            return

    if args.gray:
        if os.path.exists(base+".nrm.png"):
            gray = ocrolib.read_image_gray(base+".nrm.png")
            checktype(gray, GRAYSCALE)
        else:
            print_error("Grayscale version %s.nrm.png not found. Use ocropus-nlbin for creating normalized grayscale version of the pages as well." % base)
            return

    binary = 1-binary # invert

    if args.scale == 0:
        scale = psegutils.estimate_scale(binary)
    else:
        scale = args.scale
    print_info("scale %f" % (scale))
    if np.isnan(scale) or scale > 1000.0:
        print_error("%s: bad scale (%g); skipping\n" % (fname, scale))
        return
    if scale < args.minscale:
        print_error("%s: scale (%g) less than --minscale; skipping\n" % (fname, scale))
        return

    # find columns and text lines

    if not args.quiet: print_info("computing segmentation")
    segmentation = compute_segmentation(binary, scale)
    if np.amax(segmentation) > args.maxlines:
        print_error("%s: too many lines %g" % (fname, np.amax(segmentation)))
        return
    if not args.quiet: print_info("number of lines %g" % np.amax(segmentation))

    # compute the reading order

    if not args.quiet: print_info("finding reading order")
    lines = psegutils.compute_lines(segmentation, scale)
    order = psegutils.reading_order([l.bounds for l in lines])
    lsort = psegutils.topsort(order)

    # renumber the labels so that they conform to the specs
    # (0x010000 plus the 1-based position in reading order)

    nlabels = np.amax(segmentation)+1
    renumber = np.zeros(nlabels, 'i')
    for position, v in enumerate(lsort):
        renumber[lines[v].label] = 0x010000+(position+1)
    segmentation = renumber[segmentation]

    # finally, output everything

    if not args.quiet: print_info("writing lines")
    if not os.path.exists(outputdir):
        os.mkdir(outputdir)
    lines = [lines[v] for v in lsort]
    ocrolib.write_page_segmentation("%s.pseg.png" % outputdir, segmentation)
    cleaned = ocrolib.remove_noise(binary, args.noise)
    # the loop variables are deliberately distinct from `jobno`, so the
    # summary below reports the job number and not the last line index
    for lineno, l in enumerate(lines):
        binline = psegutils.extract_masked(1-cleaned, l, pad=args.pad, expand=args.expand)
        ocrolib.write_image_binary("%s/01%04x.bin.png" % (outputdir, lineno+1), binline)
        if args.gray:
            grayline = psegutils.extract_masked(gray, l, pad=args.pad, expand=args.expand)
            ocrolib.write_image_gray("%s/01%04x.nrm.png" % (outputdir, lineno+1), grayline)
    print_info("%6d  %s %4.1f %d" % (jobno, fname,  scale,  len(lines)))

# A single directory argument stands for all the "????.png" files inside it;
# anything else is taken as the list of files to process.
files = args.files
if len(files) == 1 and os.path.isdir(files[0]):
    files = glob.glob(files[0]+"/????.png")


def safe_process1(job):
    """Run process1 on `job` without letting exceptions escape.

    `job` is the ``(fname, jobno)`` pair expected by process1.  An
    OcropusException is reported either as a traceback (when its
    `trace` attribute is true) or as a one-line message; any other
    Exception is reported as a traceback.  Always returns None.
    """
    fname, _ = job
    try:
        process1(job)
    except OcropusException as e:
        if e.trace:
            traceback.print_exc()
        else:
            # the exception must be converted explicitly: str + exception
            # raises a TypeError
            print_info(fname + ":" + str(e))
    except Exception:
        traceback.print_exc()

# Process all pages, either one after the other or with a worker pool.
# Both paths go through safe_process1, so a failure on one page is reported
# without aborting the remaining pages.
if args.parallel<2:
    for i,f in enumerate(files):
        if args.parallel==0: print_info(f)
        safe_process1((f,i+1))
else:
    # jobs are (filename, 1-based job number) pairs
    jobs = [(f,i+1) for i,f in enumerate(files)]
    pool = Pool(processes=args.parallel)
    result = pool.map(safe_process1,jobs)
    # release the worker processes once all pages are done
    pool.close()
    pool.join()
