#!/usr/bin/env python
"""
Compare two images by PSNR. Exit with status 0 if the PSNR reaches the
threshold (the images match), 1 if it does not, or 2 on a usage or I/O error.
"""
import sys
import getopt
from PIL import Image
import numpy as np


# Command name substituted into the usage/help messages; it is set to
# sys.argv[0] when this file is run as a script.
program_name = None


def _load_rgb(filename):
    """Load the image at *filename* as a float array of shape (rows, cols, 3).

    The image is converted to RGB before it is turned into an array, so
    grayscale, palette and alpha-carrying images all yield exactly three
    channels (any alpha channel is dropped).  The underlying file is closed
    as soon as the pixel data has been copied.
    """
    with Image.open(filename) as im:
        return np.asarray(im.convert('RGB'), dtype=float)


def compare(filename1, filename2):
    """Return the PSNR, in dB, between the images in two files.

    Both images are converted to RGB.  If their sizes differ, each one is
    placed in the top-left corner of a zero-filled (black) canvas large
    enough to hold both, so the non-overlapping area counts as error.

    Args:
        filename1: path of the first image.
        filename2: path of the second image.

    Returns:
        ``np.inf`` when the mean squared error is zero (identical pixels),
        otherwise ``10 * log10(255**2 / mse)``.

    Errors raised by ``Image.open`` (for example ``IOError`` for a missing
    or unreadable file) propagate to the caller.
    """
    data1 = _load_rgb(filename1)
    data2 = _load_rgb(filename2)

    # numpy image arrays are indexed (row, column, channel).
    rows1, cols1 = data1.shape[:2]
    rows2, cols2 = data2.shape[:2]
    rows = max(rows1, rows2)
    cols = max(cols1, cols2)

    # Float canvases: subtracting uint8 data directly would wrap around.
    padded1 = np.zeros((rows, cols, 3))
    padded2 = np.zeros((rows, cols, 3))
    padded1[:rows1, :cols1] = data1
    padded2[:rows2, :cols2] = data2

    mse = np.mean((padded2 - padded1) ** 2)
    if mse == 0:
        return np.inf

    # 255 is the peak value of an 8-bit channel.
    return 10 * np.log10(255.0 ** 2 / mse)


def usage():
    """Write the short two-line usage hint to standard error.

    The command name is taken from the module-level ``program_name``.
    """
    # sys.stderr.write replaces the Python 2-only ``print >>`` statement
    # (which raises TypeError on Python 3); the emitted bytes are the same.
    sys.stderr.write('''\
Usage: {program_name} FILE1 FILE2
Try `{program_name} --help' for more information.
'''.format(program_name=program_name))


def help():
    """Print the full help text (synopsis and options) to standard output.

    The command name is taken from the module-level ``program_name``.
    """
    # A single parenthesised argument makes this valid as both the Python 2
    # print statement and the Python 3 print function.
    print('''\
Usage: {program_name} [-p] [-t THRESHOLD] FILE1 FILE2

Compare image FILE1 to image FILE2 by PSNR.

Optional arguments:
  -t THRESHOLD          PSNR threshold (default: 20)
  -p                    print PSNR\
'''.format(program_name=program_name))


def main(argv=None):
    """Run the command-line interface and return the process exit status.

    Args:
        argv: option/argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        0 if the PSNR of the two images is at least the threshold (or after
        ``--help``), 1 if it is below the threshold, 2 on a usage error or
        when the images cannot be compared.
    """
    global program_name
    program_name = sys.argv[0]
    if argv is None:
        argv = sys.argv[1:]

    try:
        opts, args = getopt.getopt(argv, 't:p', ['help'])
    except getopt.GetoptError as err:
        # Say what was wrong with the options before the generic hint.
        sys.stderr.write('%s: %s\n' % (program_name, err))
        usage()
        return 2

    print_psnr = False
    threshold = 20

    for opt, arg in opts:
        if opt == '-t':
            try:
                # PSNR is a real number, so accept fractional thresholds.
                threshold = float(arg)
                if threshold != threshold:
                    # NaN never satisfies ``psnr >= threshold``.
                    raise ValueError(arg)
            except ValueError:
                sys.stderr.write('Invalid value for %s: %s\n' % (opt, arg))
                return 2
        elif opt == '-p':
            print_psnr = True
        elif opt == '--help':
            help()
            return 0

    if len(args) != 2:
        usage()
        return 2

    filename1, filename2 = args

    try:
        psnr = compare(filename1, filename2)
    except (IOError, OSError) as e:
        # Not every I/O error carries filename/strerror; fall back to the
        # program name and the exception text instead of "None: None".
        subject = e.filename or program_name
        detail = e.strerror or str(e)
        sys.stderr.write('%s: %s\n' % (subject, detail))
        return 2
    except Exception as e:
        # An uncaught exception would exit with status 1, which callers
        # could not tell apart from "images do not match".
        sys.stderr.write('%s: %s\n' % (program_name, e))
        return 2

    if print_psnr:
        print(psnr)
    return 0 if psnr >= threshold else 1


if __name__ == '__main__':
    sys.exit(main())
