#!/usr/bin/env python
from __future__ import print_function, unicode_literals
import json, sys, os, argparse, itertools, re, shutil
from contextlib import closing

import json_delta

# msvcrt only exists on Windows.  When it can be imported, switch the
# standard input and output file descriptors to binary mode
# (os.O_BINARY).  On other platforms the import fails and nothing is
# changed.
try:
    import msvcrt
except ImportError:
    pass
else:
    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

# Matches one header line of a unified diff and exposes three named
# groups:
#   side      -- the three-character marker, '---' or '+++'
#   filename  -- the text between the marker and the timestamp
#   timestamp -- 'YYYY-MM-DD HH:MM:SS' with an optional fractional
#                part, followed by whitespace and arbitrary trailing
#                text (presumably a timezone offset -- confirm against
#                the output of json_diff -u)
# re.VERBOSE lets the pattern be laid out over several lines;
# re.MULTILINE makes '^' and '$' anchor at line boundaries, so the
# header need not be the only line in the string being matched.
HEADER_PATTERN = re.compile(
    r'''
        ^(?P<side>[+-]{3})\s
         (?P<filename>.+)\s
         (?P<timestamp>\d{4}-\d{2}-\d{2}\s
                       \d{2}:\d{2}:\d{2}(?:\.\d+)?\s
                       .+)$
    ''',
    re.MULTILINE | re.VERBOSE)


def udiff_with_headers(diff):
    '''Parse the two header lines at the start of a unified diff.

    The first line of diff must be a '---' header and the second a
    '+++' header, both matching HEADER_PATTERN.  On success a pair
    (left, right) of dictionaries is returned, each holding the
    'filename' and 'timestamp' fields of the corresponding header.
    If either header is missing or out of order, False is returned.'''
    parsed = []
    remainder = diff
    for marker in ('---', '+++'):
        header = HEADER_PATTERN.match(remainder)
        if header is None or header.group('side') != marker:
            return False
        fields = header.groupdict()
        fields.pop('side')
        parsed.append(fields)
        # Step over the newline that terminated this header line.
        remainder = remainder[header.end() + 1:]
    return tuple(parsed)

def is_udiff(diff):
    '''Boolean test whether diff can be read as a unified diff.

    Every line of diff must either be empty or begin with one of
    '+', '-' or ' '.  An empty string passes the test.'''
    return all(
        not line or line.startswith(('+', '-', ' '))
        for line in diff.split('\n')
    )

def parse_normal(diff):
    '''Try to read diff as a normal (JSON-serialized) diff.

    diff is decoded with json_delta._util.decode_json; if that raises
    ValueError, False is returned.  Otherwise the result of
    json_delta._util.check_diff_structure on the decoded value is
    returned.'''
    try:
        decoded = json_delta._util.decode_json(diff)
    except ValueError:
        return False
    else:
        return json_delta._util.check_diff_structure(decoded)

def diff_if_any(diff, ns, norm_check=parse_normal):
    '''Validate diff, record its format on the namespace ns, and
    return it serialized with json.dumps.

    Unless ns.unified is set without ns.normal, diff is first offered
    to norm_check.  Any result other than False means it is a normal
    diff: ns.normal is set, ns.unified is cleared, and the checked
    diff is returned as JSON (a result of True is treated as the
    empty diff).  Normal diffs cannot be combined with ns.reverse.

    Otherwise diff is tried as a unified diff, first with headers (in
    which case ns is updated from them) and then without.

    Raises Exception, after printing a message to stderr, if the diff
    cannot be parsed or if reversal of a normal diff is requested.'''
    try_normal = ns.normal or not ns.unified
    if try_normal:
        checked = norm_check(diff)
        # Equality rather than identity on purpose: an empty diff
        # ([]) is a valid result and compares unequal to False.
        if checked != False:
            if ns.reverse:
                print('Cannot reverse JSON-format patches.', file=sys.stderr)
                raise Exception
            ns.normal, ns.unified = True, False
            return json.dumps([] if checked is True else checked)

    if isinstance(diff, bytes):
        diff = json_delta._util.decode_udiff(diff)

    headers = udiff_with_headers(diff)
    if headers:
        update_namespace_headers(ns, headers)
    elif is_udiff(diff):
        ns.unified = True
        ns.diff = diff
        if ns.output is None:
            ns.output = sys.stdout
    else:
        print("Cannot parse diff.", file=sys.stderr)
        raise Exception
    return json.dumps(diff)

def update_namespace_headers(ns, udiff_parse):
    '''Update the namespace ns from the headers of a unified diff.

    udiff_parse is the (left, right) pair of dictionaries produced by
    udiff_with_headers.  ns.unified is set, and the file name found
    in the headers (with ns.strip leading components removed, as for
    patch -p) is used to fill in whichever of ns.struc and ns.output
    is still None: ns.struc is loaded from that file as JSON, and
    ns.output is set to the path itself.

    Raises IOError, after printing a message to stderr, if the file
    named in the headers cannot be read.'''
    ns.unified = True
    left, right = udiff_parse
    # A reversed patch was created with the old and new files swapped,
    # so the file to patch is then the one named on the '+++' line.
    # NOTE(review): confirm this choice against the header conventions
    # of json_diff -u.
    header = right if getattr(ns, 'reverse', False) else left
    struc_path = trim_path(header['filename'], getattr(ns, 'strip', 0))
    if ns.struc is None:
        try:
            with open(struc_path) as f:
                ns.struc = json.load(f)
        except IOError:
            print(("Couldn't read specified file %s" % struc_path), 
                  file=sys.stderr)
            raise
    if ns.output is None:
        ns.output = struc_path

def infer(ns):
    '''Work out where the structure and the diff come from.

    Returns a triple (struc, diff, both) that is passed on to
    json_delta.load_and_patch or json_delta.load_and_upatch:

    * Neither ns.struc nor ns.diff given: standard input is read.  If
      it is a unified diff with headers, ns is updated from the
      headers and the raw input is returned as diff, with both set to
      None.  Otherwise the input must decode as a JSON pair
      [struc, diff]; the diff part is validated and stored on
      ns.diff, and the raw input is returned as both, with struc and
      diff set to None.
    * Only ns.struc given: the diff is read from standard input.
    * Both given: the diff is read from ns.diff.

    Any exception raised while decoding the [struc, diff] pair is
    re-raised after a message has been printed to stderr.'''
    # Only one branch below supplies a combined document; every other
    # path must still return a defined value for it.
    both = None
    if ns.struc is None and ns.diff is None:
        stdin = json_delta._util.read_bytestring(sys.stdin)
        try:
            udiff_parse = udiff_with_headers(
                json_delta._util.decode_udiff(stdin)
            )
        except (ValueError, UnicodeError):
            udiff_parse = None

        if udiff_parse:
            diff = stdin
            update_namespace_headers(ns, udiff_parse)
            struc = ns.struc
        else:
            try:
                [struc, diff] = json_delta._util.decode_json(stdin)
            except Exception:
                print("Couldn't read structure and diff from standard input",
                      file=sys.stderr)
                raise
            ns.diff = diff_if_any(
                diff, ns, json_delta._util.check_diff_structure
            )
            # The patch functions are handed the raw combined input,
            # so the separately decoded parts are not passed on.
            struc = None
            diff = None
            both = stdin
    elif ns.diff is None:
        struc = ns.struc
        diff = diff_if_any(json_delta._util.read_bytestring(sys.stdin), ns)
    else:
        struc = ns.struc
        diff = diff_if_any(json_delta._util.read_bytestring(ns.diff), ns)
    return struc, diff, both

def trim_path(path, slashes):
    '''Trim the shortest prefix containing a number of path separators
    equal to slashes from path.

    This function is supposed to mimic the behaviour of the -p option
    to patch: sequences of adjacent slashes are treated as one slash,
    and if there are not enough slashes in path, it is returned
    unchanged (but in a normalized form).

    (Note: the following doctests will fail on a platform where os.sep
    != '/'!)
    
    >>> trim_path('/foo/bar/baz', 1)
    'foo/bar/baz'
    >>> trim_path('/foo/bar/baz', 2)
    'bar/baz'
    >>> trim_path('/foo/bar/baz', 3)
    'baz'
    >>> trim_path('/foo//bar/baz', 4)
    '/foo/bar/baz'
    >>> trim_path('//foo/bar', 1)
    'foo/bar'
    '''
    normalized = os.path.normpath(path)
    elems = normalized.split(os.sep)
    # os.path.normpath collapses repeated separators everywhere except
    # that, on POSIX, a path starting with exactly two slashes keeps
    # both.  Splitting then yields two empty leading components, which
    # would be counted as two slashes; drop the extra so that the
    # leading run counts once.  The bare root is left alone.
    while len(elems) > 2 and elems[0] == '' and elems[1] == '':
        del elems[0]
    if len(elems) <= slashes:
        # Not enough separators: hand back the normalized path as is.
        return normalized
    return os.sep.join(elems[slashes:])
        
def main():
    '''Command-line entry point for json_patch.

    Parses the command line, uses infer to decide where the structure
    and the diff come from, applies the patch as a unified or normal
    diff according to ns.unified, and writes the compact JSON result
    to ns.output in the requested encoding.  Returns 0.'''
    version_text = '''%(prog)s - part of json-delta {}
Copyright 2012-2015 Philip J. Roberts <himself@phil-roberts.name>. 
BSD License applies; see http://opensource.org/licenses/BSD-2-Clause
'''.format(json_delta.__VERSION__)
    description = '''Applies a patch of the form generated by json_diff to a
        JSON-serialized data structure.  If no arguments are
        specified, stdin will be expected to be a JSON structure: this
        could be a pair [struc, diff], in which case the output will
        be written to stdout.  Or it could be a unified diff with
        headers, as output by json_delta -u.  In this case, the
        program reads the name of the file to patch against out of the
        headers.  Output, however, is always written to the file
        specified by the '--output' option (stdout by default).
        '''

    cli = argparse.ArgumentParser(
        prog='json_patch',
        description=description,
        epilog='Report bugs to <himself@phil-roberts.name>',
    )

    # Positional inputs.
    cli.add_argument(
        'struc', nargs='?', type=argparse.FileType('rb'),
        help='Filename for the structure to modify.',
    )
    cli.add_argument(
        'diff', nargs='?', type=argparse.FileType('rb'),
        help=('Filename for the diff to apply.  Will read from stdin '
              'by default.'),
    )

    # General options.
    cli.add_argument(
        '--output', '-o', type=argparse.FileType('wb'),
        help='Filename to output to.  Standard output is the default.',
        metavar='FILE', default=sys.stdout,
    )
    cli.add_argument(
        '--strip', '-p', type=int, metavar='NUM', default=0,
        help='Strip NUM leading components from file names.',
    )
    cli.add_argument(
        '--reverse', '-R', action='store_true',
        help="Assume patches were created with old and new files swapped.",
    )
    cli.add_argument(
        '--encoding', help='Select encoding for output (default UTF-8).',
        default='UTF-8',
    )
    cli.add_argument('--version', action='version', version=version_text)

    # The patch format may be forced one way or the other, not both.
    patch_format = cli.add_mutually_exclusive_group()
    patch_format.add_argument(
        '--unified', '-u', action='store_true',
        help=("Interpret the patch as a unified difference "
              "(As emitted by 'json_diff -u')."),
    )
    patch_format.add_argument(
        '--normal', '-n', action='store_true',
        help=("Interpret the patch as a normal (i.e. JSON-serialized) "
              "difference."),
    )

    ns = cli.parse_args()
    struc, diff, both = infer(ns)
    with closing(ns.output) as sink:
        # Encoded bytes are written below, so prefer the stream's
        # 'buffer' attribute when it has one (as text-mode streams
        # such as sys.stdout do on Python 3).
        sink = getattr(sink, 'buffer', sink)
        if ns.unified:
            patched = json_delta.load_and_upatch(struc, diff, both,
                                                 reverse=ns.reverse)
        else:
            patched = json_delta.load_and_patch(struc, diff, both)
        serialized = json_delta._util.compact_json_dumps(patched)
        sink.write(serialized.encode(ns.encoding))
    return 0

if __name__ == '__main__':
    sys.exit(main())
