#!/usr/bin/env python

from __future__ import print_function

import argparse
import os
import multiprocessing
import sys

import numpy as np
import matplotlib.pyplot as plt
# sub-namespaces have been deprecated, but we like to use them anyway
from scipy import ndimage as filters
from scipy import ndimage as interpolation
from scipy import ndimage as morphology
from scipy import ndimage as measurements
from scipy import stats

import ocrolib


parser = argparse.ArgumentParser("""
Image binarization using non-linear processing.

This is a compute-intensive binarization method that works on degraded
and historical book pages.
""")

parser.add_argument('-n','--nocheck',action="store_true",
                    help="disable error checking on inputs")
parser.add_argument('-t','--threshold',type=float,default=0.5,help='threshold, determines lightness, default: %(default)s')
parser.add_argument('-z','--zoom',type=float,default=0.5,help='zoom for page background estimation, smaller=faster, default: %(default)s')
parser.add_argument('-e','--escale',type=float,default=1.0,help='scale for estimating a mask over the text region, default: %(default)s')
parser.add_argument('-b','--bignore',type=float,default=0.1,help='ignore this much of the border for threshold estimation, default: %(default)s')
parser.add_argument('-p','--perc',type=float,default=80,help='percentage for filters, default: %(default)s')
parser.add_argument('-r','--range',type=int,default=20,help='range for filters, default: %(default)s')
parser.add_argument('-m','--maxskew',type=float,default=2,help='skew angle estimation parameters (degrees), default: %(default)s')
parser.add_argument('-g','--gray',action='store_true',help='force grayscale processing even if image seems binary')
parser.add_argument('--lo',type=float,default=5,help='percentile for black estimation, default: %(default)s')
parser.add_argument('--hi',type=float,default=90,help='percentile for white estimation, default: %(default)s')
parser.add_argument('--skewsteps',type=int,default=8,help='steps for skew angle estimation (per degree), default: %(default)s')
parser.add_argument('--debug',type=float,default=0,help='display intermediate results, default: %(default)s')
parser.add_argument('--show',action='store_true',help='display final result')
parser.add_argument('--rawcopy',action='store_true',help='also copy the raw image')
parser.add_argument('-o','--output',default=None,help="output directory")
parser.add_argument('files',nargs='+')
parser.add_argument('-Q','--parallel',type=int,default=0)
args = parser.parse_args()

args.files = ocrolib.glob_all(args.files)

if len(args.files)<1:
    parser.print_help()
    sys.exit(0)


def print_info(*objs):
    print("INFO: ", *objs, file=sys.stdout)


def print_error(*objs):
    print("ERROR: ", *objs, file=sys.stderr)


def check_page(image):
    if len(image.shape)==3: return "input image is color image %s"%(image.shape,)
    if np.mean(image)<np.median(image): return "image may be inverted"
    h,w = image.shape
    if h<600: return "image not tall enough for a page image %s"%(image.shape,)
    if h>10000: return "image too tall for a page image %s"%(image.shape,)
    if w<600: return "image too narrow for a page image %s"%(image.shape,)
    if w>10000: return "line too wide for a page image %s"%(image.shape,)
    return None

def estimate_skew_angle(image,angles):
    estimates = []
    for a in angles:
        v = np.mean(interpolation.rotate(image,a,order=0,mode='constant'),axis=1)
        v = np.var(v)
        estimates.append((v,a))
    if args.debug>0:
        plt.plot([y for x,y in estimates],[x for x,y in estimates])
        plt.ginput(1,args.debug)
    _,a = max(estimates)
    return a


def H(s): return s[0].stop-s[0].start
def W(s): return s[1].stop-s[1].start
def A(s): return W(s)*H(s)


def dshow(image,info):
    if args.debug<=0: return
    plt.ion()
    plt.gray()
    plt.imshow(image)
    plt.title(info)
    plt.ginput(1,args.debug)


def normalize_raw_image(raw):
    ''' perform image normalization '''
    image = raw-np.amin(raw)
    if np.amax(image)==np.amin(image):
        print_info("# image is empty: %s" % (fname))
        return None
    image /= np.amax(image)
    return image


def estimate_local_whitelevel(image, zoom=0.5, perc=80, range=20, debug=0):
    '''flatten it by estimating the local whitelevel
    zoom for page background estimation, smaller=faster, default: %(default)s
    percentage for filters, default: %(default)s
    range for filters, default: %(default)s
    '''
    m = interpolation.zoom(image,zoom)
    m = filters.percentile_filter(m,perc,size=(range,2))
    m = filters.percentile_filter(m,perc,size=(2,range))
    m = interpolation.zoom(m,1.0/zoom)
    if debug>0:
        plt.clf()
        plt.imshow(m,vmin=0,vmax=1)
        plt.ginput(1,debug)
    w,h = np.minimum(np.array(image.shape),np.array(m.shape))
    flat = np.clip(image[:w,:h]-m[:w,:h]+1,0,1)
    if debug>0:
        plt.clf()
        plt.imshow(flat,vmin=0,vmax=1)
        plt.ginput(1,debug)
    return flat


def estimate_skew(flat, bignore=0.1, maxskew=2, skewsteps=8):
    ''' estimate skew angle and rotate'''
    d0,d1 = flat.shape
    o0,o1 = int(bignore*d0),int(bignore*d1) # border ignore
    flat = np.amax(flat)-flat
    flat -= np.amin(flat)
    est = flat[o0:d0-o0,o1:d1-o1]
    ma = maxskew
    ms = int(2*maxskew*skewsteps)
    # print(linspace(-ma,ma,ms+1))
    angle = estimate_skew_angle(est,np.linspace(-ma,ma,ms+1))
    flat = interpolation.rotate(flat,angle,mode='constant',reshape=0)
    flat = np.amax(flat)-flat
    return flat, angle



def estimate_thresholds(flat, bignore=0.1, escale=1.0, lo=5, hi=90, debug=0):
    '''# estimate low and high thresholds
    ignore this much of the border for threshold estimation, default: %(default)s
    scale for estimating a mask over the text region, default: %(default)s
    lo percentile for black estimation, default: %(default)s
    hi percentile for white estimation, default: %(default)s
    '''
    d0,d1 = flat.shape
    o0,o1 = int(bignore*d0),int(bignore*d1)
    est = flat[o0:d0-o0,o1:d1-o1]
    if escale>0:
        # by default, we use only regions that contain
        # significant variance; this makes the percentile
        # based low and high estimates more reliable
        e = escale
        v = est-filters.gaussian_filter(est,e*20.0)
        v = filters.gaussian_filter(v**2,e*20.0)**0.5
        v = (v>0.3*np.amax(v))
        v = morphology.binary_dilation(v,structure=np.ones((int(e*50),1)))
        v = morphology.binary_dilation(v,structure=np.ones((1,int(e*50))))
        if debug>0:
            plt.imshow(v)
            plt.ginput(1,debug)
        est = est[v]
    lo = stats.scoreatpercentile(est.ravel(),lo)
    hi = stats.scoreatpercentile(est.ravel(),hi)
    return lo, hi


def process1(job):
    fname,i = job
    print_info("# %s" % (fname))
    if args.parallel<2: print_info("=== %s %-3d" % (fname, i))
    raw = ocrolib.read_image_gray(fname)
    dshow(raw,"input")
    # perform image normalization
    image = normalize_raw_image(raw)

    if not args.nocheck:
        check = check_page(np.amax(image)-image)
        if check is not None:
            print_error(fname+"SKIPPED"+check+"(use -n to disable this check)")
            return

    # check whether the image is already effectively binarized
    if args.gray:
        extreme = 0
    else:
        extreme = (np.sum(image<0.05)+np.sum(image>0.95))*1.0/np.prod(image.shape)
    if extreme>0.95:
        comment = "no-normalization"
        flat = image
    else:
        comment = ""
        # if not, we need to flatten it by estimating the local whitelevel
        if args.parallel<2: print_info("flattening")
        flat = estimate_local_whitelevel(image, args.zoom, args.perc, args.range, args.debug)

    # estimate skew angle and rotate
    if args.maxskew>0:
        if args.parallel<2: print_info("estimating skew angle")
        flat, angle = estimate_skew(flat, args.bignore, args.maxskew, args.skewsteps)
    else:
        angle = 0

    # estimate low and high thresholds
    if args.parallel<2: print_info("estimating thresholds")
    lo, hi = estimate_thresholds(flat, args.bignore, args.escale, args.lo, args.hi, args.debug)
    # rescale the image to get the gray scale image
    if args.parallel<2: print_info("rescaling")
    flat -= lo
    flat /= (hi-lo)
    flat = np.clip(flat,0,1)
    if args.debug>0:
        plt.imshow(flat,vmin=0,vmax=1)
        plt.ginput(1,args.debug)
    bin = 1*(flat>args.threshold)

    # output the normalized grayscale and the thresholded images
    print_info("%s lo-hi (%.2f %.2f) angle %4.1f %s" % (fname, lo, hi, angle, comment))
    if args.parallel<2: print_info("writing")
    if args.debug>0 or args.show:
        plt.clf()
        plt.gray()
        plt.imshow(bin)
        plt.ginput(1,max(0.1,args.debug))
    if args.output:
        if args.rawcopy: ocrolib.write_image_gray(args.output+"/%04d.raw.png"%i,raw)
        ocrolib.write_image_binary(args.output+"/%04d.bin.png"%i,bin)
        ocrolib.write_image_gray(args.output+"/%04d.nrm.png"%i,flat)
    else:
        base,_ = ocrolib.allsplitext(fname)
        ocrolib.write_image_binary(base+".bin.png",bin)
        ocrolib.write_image_gray(base+".nrm.png",flat)

if args.debug>0 or args.show>0: args.parallel = 0

if args.output:
    if not os.path.exists(args.output):
        os.mkdir(args.output)

if args.parallel<2:
    for i,f in enumerate(args.files):
        process1((f,i+1))
else:
    pool = multiprocessing.Pool(processes=args.parallel)
    jobs = []
    for i,f in enumerate(args.files): jobs += [(f,i+1)]
    result = pool.map(process1,jobs)
