"""
Adapters and other Construct utility classes
"""
import json
import construct
from io import BytesIO
from enum import Enum
from uuid import UUID
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import Encoding

from xbox.sg.enum import PacketType


class CryptoTunnel(construct.Subconstruct):
    """
    Adapter/Tunnel for inline decryption of protected payloads.

    Depending on the packet-type, acquiring the `Initialization Vector` and
    decrypting the `protected payload` happens differently:
    * ConnectRequest / ConnectResponse: The `IV` is delivered in the
    `unprotected payload` section of the packet and can be used directly
    * Messages: The `IV` is generated by encrypting the first 16 bytes of
    the packet header using the IV key.

    Inline encryption using this Tunnel is not (yet) possible due to
    limitations in Construct. As explained above, the `IV` for Messages is
    generated by encrypting the first 16 bytes of the header. However, there's
    currently no method of determining the length of the payload without
    building it first. Therefore, the `IV` wouldn't be correct.
    """
    def _parse(self, stream, context, path):
        # Parse the subcon from the decrypted stream instead of the raw one.
        # `decrypt` may return None when no protected payload is present.
        return self.subcon._parse(self.decrypt(stream, context), context, path)

    @staticmethod
    def decrypt(stream, context):
        """
        Verify and decrypt the protected payload of the packet in `stream`.

        Args:
            stream: Seekable stream holding the whole packet, positioned at
                the start of the protected payload.
            context: Construct parsing context. Must provide `_crypto`
                (directly or on the parent context `_`), `header` and, for
                connect packets, `unprotected_payload`.

        Returns:
            BytesIO: Stream over the decrypted payload, truncated to
            `header.protected_payload_length` bytes, or None if fewer than
            48 bytes are left in the stream.

        Raises:
            ValueError: If no crypto instance is found in the context, the
                checksum does not match or the packet type is unsupported.
        """
        # To simplify message definition, we allow Switch to be a subcon
        # of CryptoTunnel
        # However, as a side effect, this function will also be called when the
        # stream is EOF (has no protected payload)
        # To compensate for this, we check if there's at least a hash + 1 block
        # left in the stream
        pos = stream.tell()
        # NOTE: on this early return the stream position is not restored
        if len(stream.read(48)) != 48:
            return None

        # Look for the crypto instance in this context first, then the parent
        crypto = context.get('_crypto', None) or context._.get('_crypto', None)

        if not crypto:
            raise ValueError("Crypto instance not passed in context")

        # Rewind and read the complete packet: the checksum in the trailing
        # 32 bytes covers everything that precedes it
        stream.seek(0)
        buf = stream.read()

        if not crypto.verify(buf[:-32], buf[-32:]):
            raise ValueError("Checksum doesn't match")

        connect_types = [PacketType.ConnectRequest, PacketType.ConnectResponse]
        if context.header.pkt_type in connect_types:
            # Connect packets carry the IV in their unprotected payload
            iv = context.unprotected_payload.iv
        elif context.header.pkt_type == PacketType.Message:
            # Messages derive the IV from the first 16 bytes of the packet
            iv = crypto.generate_iv(buf[:16])
        else:
            raise ValueError("Incompatible packet type")

        # Decrypt everything from the original position onwards, then cut the
        # result down to the payload length announced in the header
        buf = buf[pos:]
        decrypted = crypto.decrypt(iv, buf)
        decrypted = decrypted[:context.header.protected_payload_length]

        return BytesIO(decrypted)

    def _emitparse(self, code):
        # Hack: take the subcon's generated parse expression and textually
        # replace its stream argument `io` with the decrypted stream
        code.append('from xbox.sg.utils.adapters import CryptoTunnel')
        subcode = self.subcon._emitparse(code)
        return subcode.replace('(io', '(CryptoTunnel.decrypt(io, this)')


class JsonAdapter(construct.Adapter):
    """
    Construct-Adapter that converts between JSON text and Python objects.

    Building serializes a `dict` into compact JSON with sorted keys, parsing
    runs `json.loads` on the value produced by the subcon.
    """
    def _encode(self, obj, context, path):
        # Only dicts are accepted for building
        if isinstance(obj, dict):
            return json.dumps(obj, sort_keys=True, separators=(',', ':'))
        raise TypeError('Object not of type dict')

    def _decode(self, obj, context, path):
        parsed = json.loads(obj)
        return parsed

    def _emitparse(self, code):
        # The compiled parser needs `json` in scope before using it
        code.append('import json')
        inner_expr = self.subcon._emitparse(code)
        return 'json.loads({})'.format(inner_expr)


class UUIDAdapter(construct.Adapter):
    def __init__(self, encoding=None):
        """
        Construct-Adapter for UUID field.

        Parses either `utf8` encoded or raw byte strings into :class:`UUID`
        instances.

        Args:
            encoding (str): The encoding to use. If given, the UUID is stored
                as an `SGString` with that encoding, otherwise as 16 raw
                bytes.
        """
        # Name the class explicitly: `super(self.__class__, self)` recurses
        # infinitely as soon as this adapter is subclassed.
        if encoding:
            super(UUIDAdapter, self).__init__(SGString(encoding))
        else:
            super(UUIDAdapter, self).__init__(construct.Bytes(0x10))

    def _encode(self, obj, context, path):
        """
        Convert a :class:`UUID` into the representation the subcon expects.

        Returns:
            The 16 raw bytes for a `Bytes` subcon, otherwise the upper-case
            string form of the UUID.

        Raises:
            TypeError: If `obj` is not a :class:`UUID`.
        """
        if not isinstance(obj, UUID):
            raise TypeError('Object not of type UUID')
        if isinstance(self.subcon, construct.Bytes):
            return obj.bytes
        return str(obj).upper()

    def _decode(self, obj, context, path):
        """Build a :class:`UUID` from the raw bytes or string parsed by the subcon."""
        if isinstance(self.subcon, construct.Bytes):
            return UUID(bytes=obj)
        return UUID(obj)

    def _emitparse(self, code):
        # The compiled parser needs `UUID` in scope before using it
        code.append('from uuid import UUID')
        subcon_code = self.subcon._emitparse(code)
        if isinstance(self.subcon, construct.Bytes):
            return 'UUID(bytes={})'.format(subcon_code)
        return 'UUID({})'.format(subcon_code)


class CertificateAdapter(construct.Adapter):
    def __init__(self):
        """
        Construct-Adapter for Certificate field.

        Parses and dumps the DER certificate as used in the discovery response
        messages. The certificate is stored as bytes prefixed with a
        big-endian 16 bit length.
        """
        # Name the class explicitly: `super(self.__class__, self)` recurses
        # infinitely as soon as this adapter is subclassed.
        super(CertificateAdapter, self).__init__(
            PrefixedBytes(construct.Int16ub)
        )

    def _encode(self, obj, context, path):
        """
        Dump a :class:`CertificateInfo` into its byte representation.

        Raises:
            TypeError: If `obj` is not a :class:`CertificateInfo`.
        """
        if not isinstance(obj, CertificateInfo):
            raise TypeError('Object not of type CertificateInfo')
        return obj.dump()

    def _decode(self, obj, context, path):
        """Wrap the parsed certificate bytes in a :class:`CertificateInfo`."""
        return CertificateInfo(obj)

    def _emitparse(self, code):
        # The compiled parser needs `CertificateInfo` in scope before using it
        code.append('from xbox.sg.utils.adapters import CertificateInfo')
        return 'CertificateInfo({})'.format(self.subcon._emitparse(code))


class CertificateInfo(object):
    def __init__(self, raw_cert):
        """
        Helper class for parsing a x509 certificate.

        Extracts `common_name` and `public_key` from the certificate.

        Args:
            raw_cert (bytes): The DER certificate to parse.
        """
        self.cert = x509.load_der_x509_certificate(raw_cert, default_backend())
        # The first common name attribute of the subject is used as live id
        self.liveid = self.cert.subject.get_attributes_for_oid(
            NameOID.COMMON_NAME)[0].value
        self.pubkey = self.cert.public_key()

    def dump(self, encoding=Encoding.DER):
        """
        Serialize the certificate.

        Args:
            encoding (Encoding): The serialization encoding, DER by default.

        Returns:
            bytes: The serialized certificate.
        """
        return self.cert.public_bytes(encoding)

    def __repr__(self):
        return '<%s: liveid=%s, pubkey=%s>' % (
            self.__class__.__name__, self.liveid, self.pubkey
        )

    def __eq__(self, other):
        # Comparing against a foreign type must not raise AttributeError;
        # NotImplemented lets Python fall back to its default handling.
        if not isinstance(other, CertificateInfo):
            return NotImplemented
        return self.cert == other.cert


class XSwitch(construct.Switch):
    """
    `Switch` whose compiled parser can dispatch on Enum case keys.

    If at least one case key is an Enum member, the emitted lookup table is
    keyed by the members' values and the key expression's `.value` is used
    for the lookup. Otherwise the stock `Switch` code generation is used.
    """
    def _emitparse(self, code):
        has_enum_key = any(isinstance(key, Enum) for key in self.cases)
        if not has_enum_key:
            return super(XSwitch, self)._emitparse(code)

        # Lookup table: enum value -> compiled parser of the matching case
        fname = "factory_%s" % code.allocateId()
        entries = [
            "%r : lambda io,this: %s" % (key.value, sc._compileparse(code))
            for key, sc in self.cases.items()
        ]
        code.append("%s = {%s}" % (fname, ", ".join(entries)))

        # Fallback parser for values without a matching case
        defaultfname = "compiled_%s" % code.allocateId()
        default_expr = self.default._compileparse(code)
        code.append("%s = lambda io,this: %s" % (defaultfname, default_expr))

        return "%s.get(%s.value, %s)(io, this)" % (
            fname, self.keyfunc, defaultfname
        )


class XEnum(construct.Adapter):
    def __init__(self, subcon, enum=None):
        """
        Construct-Adapter for Enum field.

        Parses the value of the subcon into a member of `enum`. Without an
        enum, parsed values are passed through unchanged.

        Args:
            subcon (Construct): The subcon to adapt.
            enum (Enum): The enum to parse into.
        """
        super(XEnum, self).__init__(subcon)
        self.enum = enum

    def _encode(self, obj, context, path):
        # Plain integers are built as-is, anything else contributes `.value`
        return obj if isinstance(obj, int) else obj.value

    def _decode(self, obj, context, path):
        if self.enum:
            return self.enum(obj)
        return obj

    def _emitparse(self, code):
        if not self.enum:
            return self.subcon._emitparse(code)

        # The compiled parser needs the enum in scope before using it
        enum_name = self.enum.__name__
        code.append('from {} import {}'.format(self.enum.__module__, enum_name))
        return '{}({})'.format(enum_name, self.subcon._emitparse(code))


class XInject(construct.Construct):
    """
    Zero-size construct that adds a line of source to compiled parsers.

    Parsing and building do nothing; only `_emitparse` has an effect, it
    appends the given code to the code generator.

    Args:
        code (str): The source line to append when compiling.
    """
    def __init__(self, code):
        super(XInject, self).__init__()
        self.code = code
        self.name = '__inject'
        self.flagbuildnone = True

    def _parse(self, stream, context, path):
        # Nothing is read from the stream
        return None

    def _build(self, obj, stream, context, path):
        # Nothing is written to the stream
        return None

    def _sizeof(self, context, path):
        return 0

    def _emitparse(self, code):
        code.append(self.code)


class TerminatedField(construct.Subconstruct):
    def __init__(self, subcon, length=1, pattern=b'\x00'):
        """
        A custom :class:`Subconstruct` that adds a termination character at
        the end of the child struct.

        Args:
            subcon (Construct): The subcon to add the terminated character to.
            length (int): The amount of termination characters to add.
            pattern (bytes): The termination pattern to use.
        """
        # Name the class explicitly: `super(self.__class__, self)` recurses
        # infinitely as soon as this class is subclassed.
        super(TerminatedField, self).__init__(subcon)
        self.padding = construct.Padding(length, pattern)

    def _parse(self, stream, context, path):
        """Parse the subcon, consume the terminator and return the subcon value."""
        obj = self.subcon._parse(stream, context, path)
        self.padding._parse(stream, context, path)

        return obj

    def _build(self, obj, stream, context, path):
        """Build the subcon, write the terminator and return the subcon result."""
        subobj = self.subcon._build(obj, stream, context, path)
        self.padding._build(obj, stream, context, path)

        return subobj

    def _sizeof(self, context, path):
        """Return the size of the subcon plus the size of the terminator."""
        # Ask the padding for its size instead of assuming one byte, so a
        # custom `length` is accounted for.
        return (self.subcon._sizeof(context, path)
                + self.padding._sizeof(context, path))

    def _emitparse(self, code):
        # Evaluate both parsers in order, keep only the subcon's result
        return '({}, {})[0]'.format(
            self.subcon._emitparse(code), self.padding._emitparse(code)
        )


def SGString(encoding='utf8'):
    """
    Create a terminated, length prefixed string field.

    The Smartglass protocol seems to always add a termination character
    behind a length prefixed string, so a `PascalString` with a big-endian
    16 bit length prefix is wrapped in a `TerminatedField`.

    Args:
        encoding (str): The string encoding to use for the `PascalString`.

    Returns:
        SGString: A null byte terminated `PascalString`.
    """
    length_prefixed = construct.PascalString(construct.Int16ub, encoding)
    return TerminatedField(length_prefixed)


def PrefixedBytes(lengthfield):
    """
    Create a byte field that is preceded by its length.

    Args:
        lengthfield (:class:`Subconstruct`): The subcon used for the length
            prefix.

    Returns:
        PrefixedBytes: `GreedyBytes` wrapped in a `Prefixed` construct.
    """
    payload = construct.GreedyBytes
    return construct.Prefixed(lengthfield, payload)


class FieldIn(object):
    """
    Callable condition that checks a context field against a set of options.

    Operates like `field in options`. The `repr` of an instance is the
    equivalent `this.<field> in [...]` expression; if any option is an Enum
    member, the options' values are compared against `this.<field>.value`.

    Args:
        field: The struct field to use.
        options (list): A list with options to execute the `in` conditional on.

    Returns:
        bool: Whether or not the field value was in the options.
    """
    def __init__(self, field, options):
        self.options = options
        self.field = field

    def __call__(self, context):
        # A missing field is looked up as None
        current = context.get(self.field, None)
        return current in self.options

    def __repr__(self):
        has_enum = any(isinstance(opt, Enum) for opt in self.options)
        if not has_enum:
            return 'this.{} in {}'.format(self.field, self.options)

        raw_values = [opt.value for opt in self.options]
        return 'this.{}.value in {}'.format(self.field, raw_values)
