"""
Serialize object that can encode data into a byte array.
"""
# Author: Justin Roosenschoon <jeroosenschoon@gmail.com>

# Licence: MIT License (c) 2021 Justin Roosenschoon

import re
from serializeme.field import Field

# Constants naming the non-fixed-size encodings a field can use. They are stored in a field's ``size`` attribute in
# place of a bit count, and Serialize.packetize() dispatches on them.
NULL_TERMINATE = "null_terminate"  # Data followed by one zero byte
PREFIX_LENGTH = "prefix_length"  # One length byte, then the data (repeated for each item of a list)
PREFIX_LEN_NULL_TERM = "prefix_len_null_term"  # As PREFIX_LENGTH, plus one zero byte after the last item
IPv4 = "ipv4"  # Dotted address string; each "."-separated component is packed as one byte
# Encodings that Serialize.__str__ reports as "variable size". IPv4 is not in this list.
VAR_PREFIXES = [NULL_TERMINATE, PREFIX_LENGTH, PREFIX_LEN_NULL_TERM]


class Serialize:
    """
    Serialize object that can encode data into a byte array. This allows users to enter values and the desired size (in
    bits or bytes), and the Serialize object will automatically handle the conversion of these values to a byte array.
    Serialize also supports variable-length data by one of the module-level constants.

    Parameters
    ----------
    :param data: dictionary
        The data to be converted to a byte array. The dictionary key-value pairs are of the form field_name: size/value
        size/value has a variety of options we can specify:
            - (): empty tuple, a single zero bit.
            - (a, b): 2-tuple of integers where a is the number of bits, and b is the value.
            - (a, b): 2-tuple of a string and integer where a (the string) is of the form "xb" for x-bits or "xB" for
                      x-bytes, and b is an integer representing the value of the field. b may also be a bytes object,
                      which is copied into the packet verbatim.
            - (a, b): 2-tuple of constant (NULL_TERMINATE, PREFIX_LENGTH, PREFIX_LEN_NULL_TERM) and a string (or, for
                      the two prefix constants, a list of strings) that is encoded as dictated by the constant.
            - (a, b): 2-tuple of the IPv4 constant and a dotted address string; every component becomes one byte.
            - a: int specifying the number of zero-bits the field is.
            - a: str where a (the string) is of the form "xb" for the number of zero x-bits or "xB" for the number of
                  zero x-bytes
        If an integer value cannot be held in the specified number of bits or bytes, or a specification is not one of
        the forms above, a ValueError is raised.

    Attributes
    ----------
    data: the dictionary the object was created from.
    fields: list of Field objects representing the Fields that were generated by the dictionary.

    """

    # Fixed-size string specification: digits followed by "b" (bits) or "B" (bytes). Compiled once for all instances.
    _SIZE_SPEC = re.compile(r"[0-9]+[bB]")

    def __init__(self, data):
        self.data = data

        self.fields = []

        self.__extract_fields()

    def __bits_to_bytes(self, bit_str):
        """
        Helper function that will convert a string of bits into a byte array.
        :param bit_str: The (non-empty) string of bits to convert. Spaces are ignored.
        :return: A big-endian byte array holding the integer value of the bits; when the number of bits is not a
                 multiple of 8 the value is right-aligned (zero bits are added in front).
        """
        bit_str = bit_str.replace(" ", "")
        return int(bit_str, 2).to_bytes((len(bit_str) + 7) // 8, byteorder='big')

    def __fixed_bits(self, value, size):
        """
        Helper function that renders the value of a fixed-size field as a string of bits.
        :param value: The value of the field.
        :param size: The width of the field in bits.
        :return: The value as exactly `size` binary digits. Values that are neither integers nor (for a one-bit
                 field) the strings "0"/"1" contribute no bits.
        :raises ValueError: If an integer value is negative or does not fit in `size` bits.
        """
        if isinstance(value, str):
            # A one-bit field may carry its bit as the text "0" or "1"; any other text is not bit data.
            return value if size == 1 and value in ("0", "1") else ""
        if not isinstance(value, int):
            return ""
        if not self.__check_bit_size(value, size):
            raise ValueError("error. " + str(value) + " cannot be fit in " + str(size) + " bits.")
        # A zero-width field holds no bits at all (zfill(0) would still leave the single digit "0").
        return format(value, "b").zfill(size) if size else ""

    @staticmethod
    def __encode_item(item):
        """
        Helper function that turns one piece of variable-length data into the bytes placed in the packet.
        :param item: A string (encoded as UTF-8) or a bytes object (used as-is).
        :return: The bytes of the item.
        """
        return item if isinstance(item, bytes) else item.encode()

    def packetize(self):
        """
        Generate a byte string from the list of fields in the object.

        Fixed-size and IPv4 fields are bit-packed one after the other. Text and bytes data is byte-oriented: bits
        collected so far are written out first (an incomplete last byte is right-aligned, exactly like left-over bits
        at the end of the packet), then the data is appended unchanged. Data is therefore never mistaken for bits,
        even when it contains the characters "0" or "1".

        :return: A byte string of the fields.
        :raises ValueError: If an integer value does not fit in its field, or an item of a length-prefixed field is
                            longer than the 255 bytes a one-byte length can describe.
        """
        packet = bytearray()
        pending = []  # Bit strings collected since the last byte-oriented write.

        def flush():
            # Write the collected bits out, 8 at a time, and start collecting afresh.
            bits = "".join(pending)
            for start in range(0, len(bits), 8):
                packet.extend(self.__bits_to_bytes(bits[start:start + 8]))
            del pending[:]

        for field in self.fields:
            size, value = field.size, field.value
            if isinstance(size, int):
                # Fixed size: raw bytes are copied as they are, anything else is bit-packed.
                if isinstance(value, bytes):
                    flush()
                    packet.extend(value)
                else:
                    pending.append(self.__fixed_bits(value, size))
            elif size == IPv4:
                # IPv4 address: one byte per dotted component.
                pending.extend(self.__fixed_bits(int(part), 8) for part in value.split("."))
            elif size in (PREFIX_LENGTH, PREFIX_LEN_NULL_TERM):
                # Every item is preceded by one byte holding its length in bytes.
                # User can enter a single value or a list of values.
                flush()
                for item in ([value] if isinstance(value, (str, bytes)) else value):
                    raw = self.__encode_item(item)
                    if len(raw) > 255:
                        raise ValueError("error. " + repr(item) + " is " + str(len(raw))
                                         + " bytes long and cannot be described by a one-byte length prefix.")
                    packet.append(len(raw))
                    packet.extend(raw)
                if size == PREFIX_LEN_NULL_TERM:
                    # A single zero byte closes the whole field.
                    packet.append(0)
            elif size == NULL_TERMINATE:
                # The data followed by a byte of zeros.
                flush()
                packet.extend(self.__encode_item(value))
                packet.append(0)
            # Any other size (e.g. "vary") has no encoding and contributes nothing.

        flush()
        return bytes(packet)

    def get_field(self, field_name):
        """
        Get a specified field from the fields list, or return None if specified field does not exist.
        :param field_name: The name of the desired field to find (compared case-insensitively).
        :return: Field: The first Field object with the specified name, or None.
        """
        wanted = field_name.lower()
        for f in self.fields:
            if f.name.lower() == wanted:
                return f
        return None

    # Helper
    def __check_bit_size(self, value, num_bits):
        """
        Helper function to check if the specified value can fit in the specified number of bits.
        :param value: The value trying to fit in num_bits
        :param num_bits: The number of bits we want to see if value can fit in.
        :return: True if value is non-negative and can fit in num_bits number of bits. False otherwise.
        """
        return 0 <= value <= 2 ** num_bits - 1

    @staticmethod
    def __parse_size(spec):
        """
        Helper function that converts a size string such as "12b" or "2B" into a number of bits.
        :param spec: The size string: a number followed by "b" (bits) or "B" (bytes).
        :return: The size in bits.
        :raises ValueError: If the text in front of the unit letter is not a number.
        """
        count = int(spec[:spec.lower().index("b")])
        # A lower-case "b" anywhere in the string means bits; otherwise the unit is bytes.
        return count if "b" in spec else count * 8

    # Helper
    def __build_field(self, name, stuff):
        """
        Helper function that turns one entry of the user-specified dictionary into a Field.
        :param name: The name of the field (the dictionary key).
        :param stuff: The size/value specification of the field (the dictionary value).
        :return: The Field described by the entry.
        :raises ValueError: If the specification is not recognised or an integer value does not fit in its size.
        """
        if stuff == ():  # Empty tuple == 1 bit, value of 0
            return Field(name=name, value=0, size=1)
        if isinstance(stuff, int):  # int == number of bits, value of 0
            return Field(name=name, value=0, size=stuff)
        if isinstance(stuff, str):  # str == number of bits/bytes, value of 0
            if self._SIZE_SPEC.match(stuff):
                return Field(name=name, value=0, size=self.__parse_size(stuff))
            # No other string option, so must have been one of the "vary" constants.
            return Field(name=name, value=stuff, size="vary")
        if isinstance(stuff, (tuple, list)):  # specified size and value.
            if len(stuff) < 2:
                raise ValueError("error. field " + repr(name) + " needs both a size and a value: " + repr(stuff))
            spec, value = stuff[0], stuff[1]
            if isinstance(spec, str):
                kind = spec.lower()
                if kind in (NULL_TERMINATE, PREFIX_LENGTH, PREFIX_LEN_NULL_TERM, IPv4):
                    return Field(name=name, value=value, size=kind)
                if "b" not in kind:
                    raise ValueError("error. unknown size " + repr(spec) + " for field " + repr(name) + ".")
                size = self.__parse_size(spec)
            elif isinstance(spec, int):
                size = spec
            else:
                raise ValueError("error. unknown size " + repr(spec) + " for field " + repr(name) + ".")
            # Only integers are range-checked; bytes values are copied verbatim by packetize().
            if isinstance(value, int) and not self.__check_bit_size(value, size):
                raise ValueError("error. " + str(value) + " cannot be fit in " + str(size) + " bits.")
            return Field(name=name, value=value, size=size)
        raise ValueError("error. unsupported specification " + repr(stuff) + " for field " + repr(name) + ".")

    # Helper
    def __extract_fields(self):
        """
        Helper function to parse the user-specified dictionary of fields upon creation of Serialize object.
        :raises ValueError: If an entry is not a supported specification or an integer value does not fit its size.
        """
        for name, stuff in self.data.items():
            self.fields.append(self.__build_field(name, stuff))

    def __str__(self):
        """
        Generate a string representation of the Serialize object by listing out all of the fields, their value,
        and their sizes.
        :return: A string representation of the Serialize object, one line per field.
        """
        lines = []
        for field in self.fields:
            if field.size not in VAR_PREFIXES:
                lines.append(field.name + ": " + str(field.size) + " bits with value " + str(field.value) + ".\n")
            else:
                lines.append(field.name + ": variable size: " + str(field.size) + ", with value "
                             + str(field.value) + ".\n")

        return "".join(lines)



