import typing
from enum import Enum
from urllib.parse import unquote_plus

import multipart
from multipart.multipart import parse_options_header

from .datastructures import FormData, Headers, UploadFile


# Public parser classes exported by this module.
__all__ = [
    "AsyncFormParser",
    "FormParser",
    "AsyncMultiPartParser",
    "MultiPartParser",
]


class FormMessage(Enum):
    """Identifiers for urlencoded form-parsing events.

    The member names parallel the ``on_*`` callbacks of ``BaseFormParser``.
    """

    # Per-field events.
    FIELD_START = 1
    FIELD_NAME = 2
    FIELD_DATA = 3
    FIELD_END = 4
    # Whole-body event.
    END = 5


class MultiPartMessage(Enum):
    """Event kinds queued by the ``BaseMultiPartParser`` callbacks.

    Each queued message is a ``(MultiPartMessage, bytes)`` pair that the
    multipart ``parse()`` loops drain after every chunk they feed.
    """

    # Part body lifecycle.
    PART_BEGIN = 1
    PART_DATA = 2
    PART_END = 3
    # Part header lifecycle.
    HEADER_FIELD = 4
    HEADER_VALUE = 5
    HEADER_END = 6
    HEADERS_FINISHED = 7
    # Whole-body event.
    END = 8


def _user_safe_decode(src: bytes, codec: str) -> str:
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")


class BaseFormParser:
    """Shared callback state for urlencoded form parsing.

    Subclasses pass ``self.callbacks`` to ``multipart.QuerystringParser``.
    The raw bytes of the field currently being parsed are buffered in
    ``self.state``. Each completed field is appended to ``self.items`` as a
    ``(name, value)`` pair of percent-decoded strings.
    """

    def __init__(self, headers: Headers) -> None:
        self.headers = headers
        # Every parser event name maps onto the bound handler of the same name.
        self.callbacks = {
            event: getattr(self, event)
            for event in (
                "on_field_start",
                "on_field_name",
                "on_field_data",
                "on_field_end",
                "on_end",
            )
        }
        # Raw bytes of the field in progress.
        self.state = {"field_name": b"", "field_value": b""}
        self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

    def on_field_start(self) -> None:
        """Reset both accumulators for a new field."""
        self.state.update(field_name=b"", field_value=b"")

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Append ``data[start:end]`` to the buffered field name."""
        piece = data[start:end]
        self.state["field_name"] = self.state["field_name"] + piece

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Append ``data[start:end]`` to the buffered field value."""
        piece = data[start:end]
        self.state["field_value"] = self.state["field_value"] + piece

    def on_field_end(self) -> None:
        """Percent-decode the buffered name and value and record the pair."""
        # latin-1 maps bytes 1:1 to code points; unquote_plus then resolves
        # '+' and %XX escapes.
        raw_name = self.state["field_name"].decode("latin-1")
        raw_value = self.state["field_value"].decode("latin-1")
        self.items.append((unquote_plus(raw_name), unquote_plus(raw_value)))

    def on_end(self) -> None:
        """Nothing to do once the whole body has been parsed."""


class AsyncFormParser(BaseFormParser):
    """Parse an urlencoded form body delivered by an async byte stream."""

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        super().__init__(headers)
        self.stream = stream

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the decoded fields.

        Empty chunks are skipped. The parser is finalized exactly once, after
        the stream is exhausted. Previously ``finalize()`` ran only when an
        empty chunk arrived, so a stream that did not end in ``b""`` lost its
        trailing field, and repeated empty chunks finalized more than once.
        """
        parser = multipart.QuerystringParser(self.callbacks)

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)

        # finalize() completes any field still in progress and fires on_end.
        parser.finalize()
        return FormData(self.items)


class FormParser(BaseFormParser):
    """Parse an urlencoded form body delivered by a synchronous iterable."""

    def __init__(
        self, headers: Headers, stream: typing.Iterable[typing.Optional[bytes]]
    ) -> None:
        super().__init__(headers)
        self.stream = stream

    def parse(self) -> FormData:
        """Consume ``self.stream`` and return the decoded fields.

        Empty or ``None`` chunks are skipped. The parser is finalized exactly
        once, after the iterable is exhausted. Previously ``finalize()`` ran
        only when a falsy chunk arrived, so a stream without a trailing empty
        chunk lost its last field, and repeated falsy chunks finalized more
        than once.
        """
        parser = multipart.QuerystringParser(self.callbacks)

        # Feed the parser with data from the request.
        for chunk in self.stream:
            if chunk:
                parser.write(chunk)

        # finalize() completes any field still in progress and fires on_end.
        parser.finalize()
        return FormData(self.items)


class BaseMultiPartParser:
    """Shared callback plumbing for ``multipart/form-data`` parsing.

    The callbacks handed to ``multipart.MultipartParser`` do no parsing of
    their own. Each one appends a ``(MultiPartMessage, bytes)`` tuple to
    ``self.messages``, which the concrete ``parse()`` implementations drain
    after every ``parser.write()`` call.
    """

    def __init__(self, headers: Headers) -> None:
        self.headers = headers
        self.callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }
        self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []

    def on_part_begin(self) -> None:
        """Queue the start of a new part."""
        message = (MultiPartMessage.PART_BEGIN, b"")
        self.messages.append(message)

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Queue a slice of part body bytes."""
        message = (MultiPartMessage.PART_DATA, data[start:end])
        self.messages.append(message)

    def on_part_end(self) -> None:
        """Queue the end of the current part."""
        message = (MultiPartMessage.PART_END, b"")
        self.messages.append(message)

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Queue a (possibly partial) header name."""
        message = (MultiPartMessage.HEADER_FIELD, data[start:end])
        self.messages.append(message)

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Queue a (possibly partial) header value."""
        message = (MultiPartMessage.HEADER_VALUE, data[start:end])
        self.messages.append(message)

    def on_header_end(self) -> None:
        """Queue the end of one header line."""
        message = (MultiPartMessage.HEADER_END, b"")
        self.messages.append(message)

    def on_headers_finished(self) -> None:
        """Queue the end of the current part's header block."""
        message = (MultiPartMessage.HEADERS_FINISHED, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Queue the end of the whole multipart body."""
        message = (MultiPartMessage.END, b"")
        self.messages.append(message)

    def parse_options_header(self) -> typing.Tuple[typing.Optional[bytes], str]:
        """Return ``(boundary, charset)`` from the request's Content-Type header.

        ``boundary`` is the raw ``boundary`` parameter, or ``None`` if it is
        absent. ``charset`` is the ``charset`` parameter decoded to ``str``,
        defaulting to ``"utf-8"``.
        """
        # Inside a method body the bare name resolves to the module-level
        # helper imported from multipart, not to this method.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        boundary = params.get(b"boundary")
        return boundary, charset


class AsyncMultiPartParser(BaseMultiPartParser):
    """Parse a ``multipart/form-data`` body delivered by an async byte stream."""

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        super().__init__(headers)
        self.stream = stream

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed parts.

        Parts whose Content-Disposition carries a ``filename`` are streamed
        into an ``UploadFile``, which is rewound to offset 0 before being
        returned. All other parts are buffered in memory and decoded with the
        request charset, falling back to latin-1.

        Raises:
            KeyError: if a part's Content-Disposition has no ``name`` parameter.
        """
        boundary, charset = self.parse_options_header()

        parser = multipart.MultipartParser(boundary, self.callbacks)
        header_field = b""
        header_value = b""
        content_disposition = None
        content_type = b""
        field_name = ""
        # A bytearray grows in place. ``bytes += chunk`` copied the whole
        # buffered value on every PART_DATA chunk, which is quadratic for
        # large non-file fields.
        data = bytearray()
        file: typing.Optional[UploadFile] = None

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            parser.write(chunk)
            # Snapshot and reset the queue the callbacks just filled.
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = bytearray()
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    # Accumulate: one header may be delivered in several callbacks.
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    _, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"], charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                        )
                    else:
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        await file.awrite(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append(
                            (field_name, _user_safe_decode(bytes(data), charset))
                        )
                    else:
                        await file.aseek(0)
                        items.append((field_name, file))
                # MultiPartMessage.END needs no handling; every part is already
                # recorded by its PART_END.
                # NOTE(review): UploadFiles created before an exception here are
                # not closed by this method — consider cleanup on failure.

        parser.finalize()
        return FormData(items)


class MultiPartParser(BaseMultiPartParser):
    """Parse a ``multipart/form-data`` body delivered by a synchronous iterable."""

    def __init__(self, headers: Headers, stream: typing.Iterable[bytes]) -> None:
        super().__init__(headers)
        self.stream = stream

    def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed parts.

        Parts whose Content-Disposition carries a ``filename`` are written into
        an ``UploadFile``, which is rewound to offset 0 before being returned.
        All other parts are buffered in memory and decoded with the request
        charset, falling back to latin-1.

        Raises:
            KeyError: if a part's Content-Disposition has no ``name`` parameter.
        """
        boundary, charset = self.parse_options_header()

        parser = multipart.MultipartParser(boundary, self.callbacks)
        header_field = b""
        header_value = b""
        content_disposition = None
        content_type = b""
        field_name = ""
        # A bytearray grows in place. ``bytes += chunk`` copied the whole
        # buffered value on every PART_DATA chunk, which is quadratic for
        # large non-file fields.
        data = bytearray()
        file: typing.Optional[UploadFile] = None

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        # Feed the parser with data from the request.
        for chunk in self.stream:
            parser.write(chunk)
            # Snapshot and reset the queue the callbacks just filled.
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = bytearray()
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    # Accumulate: one header may be delivered in several callbacks.
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    _, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"], charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                        )
                    else:
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        file.write(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append(
                            (field_name, _user_safe_decode(bytes(data), charset))
                        )
                    else:
                        file.seek(0)
                        items.append((field_name, file))
                # MultiPartMessage.END needs no handling; every part is already
                # recorded by its PART_END.
                # NOTE(review): UploadFiles created before an exception here are
                # not closed by this method — consider cleanup on failure.

        parser.finalize()
        return FormData(items)
