"""Bitarray based codec for binary data structures."""

import struct

from typing import Any, Callable, Optional

import bitarray
from bitarray.util import ba2int

import bpack
import bpack.utils

from .codec_utils import make_decoder_decorator
from .descriptors import field_descriptors


# Public API of this backend module.
__all__ = [
    'Decoder',
    'decoder',
    'BACKEND_NAME',
    'BACKEND_TYPE',
]


# Backend identification: its name and the base units it works with.
BACKEND_NAME = 'bitarray'
BACKEND_TYPE = bpack.EBaseUnits.BITS


# Signature of a converter: maps a bitarray (one field) to a field value.
FactoryType = Callable[[bitarray.bitarray], Any]


def ba_to_float_factory(size, byteorder: str = '>',
                        bitorder: str = 'big') -> FactoryType:
    """Return a function converting a bitarray into a float.

    :param size: floating point item size in bits: 16, 32 or 64
    :param byteorder: :mod:`struct` byte order character used to
        interpret the bytes of the bitarray (default: ``'>'``)
    :param bitorder: bit order of the input bitarray; only ``'big'`` is
        supported
    :returns: a callable taking a bitarray and returning a float
    :raises NotImplementedError: if *bitorder* is not ``'big'``
    :raises ValueError: if *size* is not 16, 32 or 64
    """
    # Explicit check (not ``assert``) so that it is not stripped when
    # Python runs with -O, which would silently produce wrong values.
    if bitorder != 'big':
        raise NotImplementedError(
            f'bit order "{bitorder}" is not supported by the {__name__} '
            f'backend ({BACKEND_NAME})')

    # struct format characters for half, single and double precision.
    float_formats = {16: 'e', 32: 'f', 64: 'd'}
    fmt_char = float_formats.get(size)
    if fmt_char is None:
        raise ValueError('floating point item size must be 16, 32 or 64 bits')

    codec = struct.Struct(f'{byteorder}{fmt_char}')

    def func(ba):
        """Unpack the bytes of *ba* as a single float."""
        return codec.unpack(ba.tobytes())[0]

    return func


def converter_factory(type_, size: Optional[int] = None, signed: bool = False,
                      byteorder: str = '>',
                      bitorder: str = 'big') -> FactoryType:
    """Return a function converting a bitarray into a value of *type_*.

    The conversion is selected on the effective type of *type_*
    (as computed by :func:`bpack.utils.effective_type`):

    * ``int``: the bits are interpreted as an integer (two's complement
      if *signed* is true);
    * ``float``: see :func:`ba_to_float_factory` (uses *size*,
      *byteorder* and *bitorder*);
    * ``bytes``: the raw bytes of the bitarray;
    * ``str``: the raw bytes decoded as ASCII;
    * ``bool``: ``True`` if the integer value of the bits is non-zero.

    If *type_* differs from its effective type, the converted value is
    additionally passed to ``type_(...)``.

    :raises TypeError: if *type_* is a sequence type or its effective
        type is not supported by this backend
    """
    if bpack.utils.is_sequence_type(type_, error=True):
        raise TypeError(
            f'backend "{BACKEND_NAME}" does not support sequence types: '
            f'"{type_}".')
    etype = bpack.utils.effective_type(type_)
    if etype is int:
        def func(ba):
            return ba2int(ba, signed)
    elif etype is float:
        func = ba_to_float_factory(size, byteorder, bitorder)
    elif etype is bytes:
        def func(ba):
            return ba.tobytes()
    elif etype is str:
        def func(ba):
            return ba.tobytes().decode('ascii')
    elif etype is bool:
        def func(ba):
            # use the same (directly imported) helper as the int branch
            return bool(ba2int(ba))
    else:
        raise TypeError(
            f'type "{type_}" is not supported by the {__name__} backend '
            f'({BACKEND_NAME})')

    if etype is type_:
        return func

    # Wrap the base conversion so the result has the declared type.
    def converter(x, conv_func=func):
        return type_(conv_func(x))

    return converter


def _bitorder_to_baorder(bitorder: bpack.EBitOrder) -> str:
    """Map a bpack bit order onto a bitarray endianness string.

    ``MSB`` and ``DEFAULT`` map to ``'big'``, ``LSB`` maps to
    ``'little'``; anything else raises :class:`ValueError`.
    """
    msb_first = {bpack.EBitOrder.MSB, bpack.EBitOrder.DEFAULT}
    if bitorder in msb_first:
        return 'big'
    if bitorder is bpack.EBitOrder.LSB:
        return 'little'
    raise ValueError(f'invalid bit order: "{bitorder}"')


class Decoder:
    """Bitarray based data decoder.

    Only supports "big endian" byte-order and MSB bit-order.
    """

    baseunits = bpack.EBaseUnits.BITS

    def __init__(self, descriptor, converters=converter_factory):
        """Initializer.

        The *descriptor* parameter is a bpack record descriptor whose
        base units must be bits.

        The *converters* parameter controls how the raw bitarray slice of
        each field is turned into the field value:

        * if callable, it is used as a factory and called once per field
          as ``converters(type, size, signed, byteorder)``
          (default: :func:`converter_factory`);
        * otherwise, if not ``None``, it must be an iterable providing
          exactly one converter per field; ``None`` items leave the
          corresponding raw bitarray unchanged;
        * ``None`` disables any conversion.

        :raises ValueError: if the descriptor base units are not bits or
            the number of converters does not match the number of fields
        :raises NotImplementedError: if the descriptor byte order or bit
            order is not supported by this backend
        """
        if bpack.baseunits(descriptor) is not self.baseunits:
            raise ValueError(
                f'bitarray decoder only accepts descriptors with '
                f'base units "{self.baseunits}"')

        descr_bitorder = bpack.bitorder(descriptor)
        assert descr_bitorder is not None

        byteorder = bpack.byteorder(descriptor)

        if byteorder in {bpack.EByteOrder.LITTLE, bpack.EByteOrder.NATIVE}:
            raise NotImplementedError(
                f'byte order "{byteorder}" is not supported by the {__name__} '
                f'backend ({BACKEND_NAME})')

        bitorder = _bitorder_to_baorder(descr_bitorder)
        if bitorder != 'big':
            raise NotImplementedError(
                f'bit order "{bitorder}" is not supported by the {__name__} '
                f'backend ({BACKEND_NAME})')

        # Materialized once: it is iterated both to build the converters
        # and to compute the field slices.
        field_descrs = list(field_descriptors(descriptor))

        if callable(converters):
            conv_factory = converters
            # an empty byte order value is mapped to big endian ('>')
            byteorder_str = byteorder.value if byteorder.value else '>'
            converters = [
                conv_factory(fd.type, fd.size, fd.signed, byteorder_str)
                for fd in field_descrs
            ]

        if converters is not None:
            converters = list(converters)
            n_fields = len(list(bpack.fields(descriptor)))
            if len(converters) != n_fields:
                raise ValueError(
                    f'the number of converters ({len(converters)}) does not '
                    f'match the number of fields ({n_fields})')

        self._descriptor = descriptor
        self._converters = converters
        self._slices = [
            slice(fd.offset, fd.offset + fd.size) for fd in field_descrs
        ]
        # Minimum number of bits the input must provide to hold all fields.
        self._min_nbits = max(
            (slice_.stop for slice_ in self._slices), default=0)

    def decode(self, data: bytes):
        """Decode binary data and return a record object.

        :raises ValueError: if *data* is too short to contain all the
            fields of the record
        """
        # Explicit big-endian bit order: it is the only one supported by
        # this decoder and it must not depend on the bitarray default.
        ba = bitarray.bitarray(endian='big')
        ba.frombytes(data)
        if len(ba) < self._min_nbits:
            # slicing past the end would silently yield truncated fields
            raise ValueError(
                f'not enough data to decode the record: {len(ba)} bits '
                f'available, at least {self._min_nbits} required')
        values = [ba[slice_] for slice_ in self._slices]

        if self._converters is not None:
            values = [
                convert(value) if convert is not None else value
                for convert, value in zip(self._converters, values)
            ]

        return self._descriptor(*values)


# Public ``decoder`` helper (listed in __all__) built by
# make_decoder_decorator for this backend's Decoder class.
# NOTE(review): presumably a class decorator attaching decoding support
# to a record descriptor — confirm against codec_utils.
decoder = make_decoder_decorator(Decoder)
