# Copyright (C) 2015-2018  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import base64
import datetime
from json import JSONDecoder, JSONEncoder
import types
from uuid import UUID

import arrow
import dateutil.parser
import msgpack


def encode_data_client(data):
    """Serialize ``data`` with :func:`msgpack_dumps`.

    Args:
        data: the object to serialize.

    Returns:
        Whatever :func:`msgpack_dumps` returns for ``data``.

    Raises:
        ValueError: if serialization raises :class:`OverflowError`. The
            original error is kept as ``__cause__`` and its text is
            appended to the message.

    """
    try:
        return msgpack_dumps(data)
    except OverflowError as e:
        # Chain explicitly so the traceback shows the overflow as the
        # direct cause rather than as an incidental error during handling.
        raise ValueError('Limits were reached. Please, check your input.\n' +
                         str(e)) from e


def decode_response(response):
    """Decode the body of ``response`` according to its content type.

    Args:
        response: an object exposing a ``headers`` mapping, a ``content``
            attribute and a ``json()`` method (presumably a
            ``requests.Response`` -- confirm against callers).

    Returns:
        The decoded payload: ``response.content`` run through
        :func:`msgpack_loads` for ``application/x-msgpack``, or
        ``response.json()`` using :class:`SWHJSONDecoder` for
        ``application/json``.

    Raises:
        KeyError: if the ``content-type`` header is missing.
        ValueError: if the content type is neither msgpack nor JSON.

    """
    mime = response.headers['content-type']

    if mime.startswith('application/x-msgpack'):
        return msgpack_loads(response.content)

    if mime.startswith('application/json'):
        return response.json(cls=SWHJSONDecoder)

    raise ValueError('Wrong content type `%s` for API response' % mime)


class SWHJSONEncoder(JSONEncoder):
    """JSON encoder for data structures generated by Software Heritage.

    On top of what the stock Python JSON encoder supports, this encoder
    knows how to serialize:

    - bytes (encoded as a Base85 string);
    - datetime.datetime (encoded as an ISO8601 string);
    - uuid.UUID (encoded as its string representation);
    - datetime.timedelta (encoded as a dictionary with the days,
      seconds and microseconds keys);
    - arrow.Arrow (encoded as an ISO8601 string).

    Each of those values is emitted as a dictionary with two keys:

    - swhtype with value 'bytes', 'datetime', 'uuid', 'timedelta' or
      'arrow';
    - d containing the encoded value.

    Any other object the stock encoder rejects is serialized as a list
    when it is iterable (which allows serializing generators).

    Caveats: Limitations in the JSONEncoder extension mechanism
    prevent us from "escaping" dictionaries that only contain the
    swhtype and d keys, and therefore arbitrary data structures can't
    be round-tripped through SWHJSONEncoder and SWHJSONDecoder.

    """

    @staticmethod
    def _wrap(swhtype, encoded):
        """Build the two-key envelope used for non-standard types."""
        return {'swhtype': swhtype, 'd': encoded}

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        Raises:
            TypeError: if ``o`` is neither a supported type nor iterable.

        """
        if isinstance(o, bytes):
            return self._wrap('bytes', base64.b85encode(o).decode('ascii'))

        if isinstance(o, datetime.datetime):
            return self._wrap('datetime', o.isoformat())

        if isinstance(o, UUID):
            return self._wrap('uuid', str(o))

        if isinstance(o, datetime.timedelta):
            parts = {
                'days': o.days,
                'seconds': o.seconds,
                'microseconds': o.microseconds,
            }
            return self._wrap('timedelta', parts)

        if isinstance(o, arrow.Arrow):
            return self._wrap('arrow', o.isoformat())

        try:
            return super().default(o)
        except TypeError as unsupported:
            # Last resort: serialize any iterable as a list. If ``o`` is
            # not iterable, surface the stock encoder's error instead of
            # the one raised by iter().
            try:
                items = iter(o)
            except TypeError:
                raise unsupported from None
            else:
                return list(items)


class SWHJSONDecoder(JSONDecoder):
    """JSON decoder for data structures encoded with SWHJSONEncoder.

    On top of what the stock Python JSON decoder supports, this decoder
    rebuilds:

    - bytes (from a Base85 string);
    - datetime.datetime (from a string parsed with dateutil);
    - uuid.UUID (from its string representation);
    - datetime.timedelta (from a dictionary of keyword arguments);
    - arrow.Arrow (from a string parsed with arrow.get).

    Such values must be encoded as a dictionary with exactly two keys:

    - swhtype with value 'bytes', 'datetime', 'uuid', 'timedelta' or
      'arrow';
    - d containing the encoded value.

    To limit the impact of our encoding, if the swhtype key doesn't
    contain a known value, the dictionary is decoded like any other
    dictionary.

    """

    def decode_data(self, o):
        """Recursively replace encoded envelopes found in ``o``.

        Lists and dictionaries are walked and rebuilt; any other value
        is returned unchanged.

        """
        if isinstance(o, list):
            return [self.decode_data(item) for item in o]

        if not isinstance(o, dict):
            return o

        if set(o) == {'swhtype', 'd'}:
            kind = o['swhtype']
            payload = o['d']
            if kind == 'bytes':
                return base64.b85decode(payload)
            if kind == 'datetime':
                return dateutil.parser.parse(payload)
            if kind == 'uuid':
                return UUID(payload)
            if kind == 'timedelta':
                return datetime.timedelta(**payload)
            if kind == 'arrow':
                return arrow.get(payload)
            # Unknown swhtype: fall through and treat as a plain dict.

        return {name: self.decode_data(item) for name, item in o.items()}

    def raw_decode(self, s, idx=0):
        """Decode like the parent class, then post-process the result."""
        parsed, end = super().raw_decode(s, idx)
        return self.decode_data(parsed), end


def msgpack_dumps(data):
    """Write data as a msgpack stream.

    Values msgpack cannot pack natively are converted by a ``default``
    hook: datetime.datetime, uuid.UUID, datetime.timedelta and
    arrow.Arrow become a tagged dictionary (a ``b'__<type>__'`` marker
    set to True plus the encoded value under ``b's'``), and generators
    are turned into lists.

    """
    def tagged(marker, payload):
        # Marker first, payload second: keeps the packed key order stable.
        return {marker: True, b's': payload}

    def to_packable(value):
        if isinstance(value, datetime.datetime):
            return tagged(b'__datetime__', value.isoformat())
        if isinstance(value, types.GeneratorType):
            return list(value)
        if isinstance(value, UUID):
            return tagged(b'__uuid__', str(value))
        if isinstance(value, datetime.timedelta):
            fields = {
                'days': value.days,
                'seconds': value.seconds,
                'microseconds': value.microseconds,
            }
            return tagged(b'__timedelta__', fields)
        if isinstance(value, arrow.Arrow):
            return tagged(b'__arrow__', value.isoformat())
        # Unknown type: hand the value back unchanged.
        return value

    packed = msgpack.packb(data, use_bin_type=True, default=to_packable)
    return packed


def msgpack_loads(data):
    """Read data as a msgpack stream.

    Dictionaries carrying a truthy ``b'__datetime__'``, ``b'__uuid__'``,
    ``b'__timedelta__'`` or ``b'__arrow__'`` marker are replaced by the
    matching Python object built from their ``b's'`` entry; every other
    dictionary is returned untouched.

    """
    def revive(mapping):
        if mapping.get(b'__datetime__'):
            return dateutil.parser.parse(mapping[b's'])
        if mapping.get(b'__uuid__'):
            return UUID(mapping[b's'])
        if mapping.get(b'__timedelta__'):
            return datetime.timedelta(**mapping[b's'])
        if mapping.get(b'__arrow__'):
            return arrow.get(mapping[b's'])
        return mapping

    # NOTE(review): this fallback also triggers on a TypeError raised from
    # inside the hook (e.g. bad timedelta arguments), not only on an
    # unsupported ``raw`` keyword -- confirm whether that is intended.
    try:
        unpacked = msgpack.unpackb(data, raw=False,
                                   object_hook=revive)
    except TypeError:  # msgpack < 0.5.2
        unpacked = msgpack.unpackb(data, encoding='utf-8',
                                   object_hook=revive)
    return unpacked
