import json
from copy import copy
from enum import Enum
from hashlib import sha256
from json import JSONDecodeError
from typing import List, Dict, Any, Optional, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from pydantic import BaseModel, Extra, Field, validator

from .abstract import BaseContent
from .program import ProgramContent


class Chain(str, Enum):
    """Blockchains a message may be attached to.

    Every member's value equals its name; thanks to the ``str`` mixin the
    members compare equal to those plain strings (``Chain.ETH == "ETH"``)
    and can be looked up from them (``Chain("ETH")``).
    """

    AVAX = "AVAX"
    CSDK = "CSDK"
    DOT = "DOT"
    ETH = "ETH"
    NEO = "NEO"
    NULS = "NULS"
    NULS2 = "NULS2"
    SOL = "SOL"


class HashType(str, Enum):
    """Hash functions accepted for computing a message's ``item_hash``.

    Only SHA-256 is currently listed.
    """

    sha256 = "sha256"


class MessageType(str, Enum):
    """Kinds of messages understood by Aleph.

    Member names are lower-case while the wire values are upper-case
    (``MessageType.post.value == "POST"``).
    """

    post = "POST"
    aggregate = "AGGREGATE"
    store = "STORE"
    program = "PROGRAM"
    forget = "FORGET"


class ItemType(str, Enum):
    """Where the content of a message is stored.

    ``inline`` means the content travels in the message's ``item_content``
    field; the other members name external storage back-ends.
    """

    inline = "inline"
    storage = "storage"
    ipfs = "ipfs"


class MongodbId(BaseModel):
    """Internal MongoDB id returned by PyAleph, shaped like ``{"$oid": "..."}``."""

    # "$oid" is not a valid Python identifier, so the field is exposed as
    # `oid` and mapped to the JSON key through its alias.
    oid: str = Field(alias="$oid")

    class Config:
        # Reject any key other than "$oid".
        extra = Extra.forbid


class ChainRef(BaseModel):
    """Another message embedded in the 'ref' field of some POST messages."""

    chain: Chain
    channel: Optional[str]
    item_content: str
    item_hash: str
    item_type: ItemType
    sender: str
    signature: str
    time: float
    # Defaults to "POST"; the field type is plain `str`.
    type: str = "POST"


class MessageConfirmationHash(BaseModel):
    """Object form of a confirmation hash, with "$binary" and "$type" keys.

    NOTE(review): the key names suggest MongoDB extended-JSON binary data —
    confirm against the node's output.
    """

    # Both JSON keys start with "$", hence the aliases.
    binary: str = Field(alias="$binary")
    type: str = Field(alias="$type")

    class Config:
        extra = Extra.forbid


class MessageConfirmation(BaseModel):
    """Where and at which height a message was confirmed on a blockchain."""

    chain: Chain
    height: int
    # Either a plain string or the "$binary"/"$type" object form.
    hash: Union[str, MessageConfirmationHash]

    class Config:
        extra = Extra.forbid


class AggregateContentKey(BaseModel):
    """Object form of an aggregate key: ``{"name": "<key>"}``."""

    name: str

    class Config:
        # Only "name" is accepted.
        extra = Extra.forbid


class PostContent(BaseContent):
    """Payload of a POST message."""

    content: Optional[Any] = Field(
        description="User-generated content of a POST message"
    )
    ref: Optional[Union[str, ChainRef]] = Field(
        default=None,
        description="Other message referenced by this one",
    )
    type: str = Field(description="User-generated 'content-type' of a POST message")

    @validator("type")
    def check_type(cls, v, values):
        # A post whose type is 'amend' must carry a non-empty 'ref'
        # ('ref' is declared above 'type', so it is already in `values`).
        if v == "amend" and not values.get("ref"):
            raise ValueError("A 'ref' is required for POST type 'amend'")
        return v

    class Config:
        extra = Extra.forbid


class AggregateContent(BaseContent):
    "Content of an AGGREGATE message"
    key: Union[str, AggregateContentKey] = Field(
        description="The aggregate key can be either a string or a dict containing the key in field 'name'"
    )
    # NOTE(review): the description says the content "must be a dict", but the
    # annotation accepts any JSON value; confirm which one is intended.
    content: Union[Dict, List, str, int, float, bool, None] = Field(
        description="The content of an aggregate must be a dict"
    )

    class Config:
        extra = Extra.forbid
        # pydantic's default Union handling tries the members left to right and
        # keeps the first one that coerces: `str` comes before int/float/bool, so
        # e.g. 42 or True would be stored as "42" / "True". `smart_union`
        # (pydantic >= 1.9) first looks for an exact type match so JSON values
        # keep their type; older versions simply ignore this setting.
        smart_union = True


class StoreContent(BaseContent):
    """Payload of a STORE message."""

    item_type: ItemType
    item_hash: str
    # `size` and `content_type` are generated by the node on storage.
    size: Optional[int]
    content_type: Optional[str]
    ref: Optional[str]

    class Config:
        # Unlike most models in this module, unknown keys are accepted.
        extra = Extra.allow


class ForgetContent(BaseContent):
    "Content of a FORGET message"
    hashes: List[str]
    reason: Optional[str]

    def __hash__(self) -> int:
        """Hash the model from its class and field values.

        ``hashes`` is a list (unhashable), so list values are converted to
        tuples first. The previous implementation hashed a ``dict_values``
        view, which hashes by object identity rather than by content, so
        equal instances did not get equal hashes.
        """
        # NOTE(review): assumes the fields inherited from BaseContent are
        # hashable once lists are frozen — confirm if BaseContent gains
        # dict-valued fields.
        frozen_values = tuple(
            tuple(value) if isinstance(value, list) else value
            for value in self.__dict__.values()
        )
        return hash((self.__class__, frozen_values))


class BaseMessage(BaseModel):
    """Base template for all messages.

    The validators below read previously declared fields from ``values``
    (e.g. ``item_type`` before ``item_content`` before ``item_hash``), so the
    declaration order of those fields must be kept. A field that failed its
    own validation is absent from ``values``; validators then skip their
    cross-field checks instead of raising a spurious error.
    """
    id_: Optional[MongodbId] = Field(alias="_id", description="MongoDB metadata")
    chain: Chain = Field(description="Blockchain used for this message")

    sender: str = Field(description="Address of the sender")
    type: MessageType = Field(
        description="Type of message (POST, AGGREGATE, STORE, PROGRAM or FORGET)"
    )
    channel: Optional[str] = Field(
        description="Channel of the message, one application ideally has one channel"
    )
    confirmations: Optional[List[MessageConfirmation]] = Field(
        description="Blockchain confirmations of the message"
    )
    confirmed: Optional[bool] = Field(
        description="Indicates that the message has been confirmed on a blockchain"
    )
    signature: str = Field(
        description="Cryptographic signature of the message by the sender"
    )
    size: Optional[int] = Field(
        description="Size of the content"
    )  # Almost always present
    time: float = Field(description="Unix timestamp when the message was published")
    item_type: ItemType = Field(description="Storage method used for the content")
    item_content: Optional[str] = Field(
        description="JSON serialization of the message when 'item_type' is 'inline'"
    )
    hash_type: Optional[HashType] = Field(
        description="Hashing algorithm used to compute 'item_hash'"
    )
    item_hash: str = Field(description="Hash of the content (sha256 by default)")
    content: BaseContent = Field(description="Content of the message, ready to be used")

    # `always=True` so that an inline message with no 'item_content' at all is
    # rejected here instead of crashing later in `check_item_hash`.
    @validator("item_content", always=True)
    def check_item_content(
        cls, v: Optional[str], values: Dict[str, Any]
    ) -> Optional[str]:
        """Require valid JSON for inline messages; forbid the field otherwise.

        Raises:
            ValueError: inline message without (valid JSON) content, or a
                non-inline message that defines ``item_content``.
        """
        item_type = values.get("item_type")
        if item_type is None:
            # 'item_type' failed its own validation; that error is reported.
            return v
        if item_type == ItemType.inline:
            if v is None:
                raise ValueError(
                    "Field 'item_content' is required when 'item_type' == 'inline'"
                )
            try:
                json.loads(v)
            except JSONDecodeError as error:
                raise ValueError(
                    "Field 'item_content' does not appear to be valid JSON"
                ) from error
        elif v is not None:
            raise ValueError(
                f"Field 'item_content' cannot be defined when 'item_type' == '{item_type}'"
            )
        return v

    @validator("item_hash")
    def check_item_hash(cls, v: str, values: Dict[str, Any]) -> str:
        """For inline messages, check ``item_hash == sha256(item_content)``.

        Raises:
            ValueError: unsupported ``hash_type`` or hash mismatch.
        """
        item_type = values.get("item_type")
        if item_type == ItemType.inline:
            item_content: Optional[str] = values.get("item_content")
            if item_content is None:
                # Missing/invalid 'item_content' is reported by check_item_content.
                return v

            # Double check that the hash function is supported
            hash_type = values.get("hash_type") or HashType.sha256
            if hash_type != HashType.sha256:
                raise ValueError(f"Unsupported 'hash_type' '{hash_type}'")

            computed_hash: str = sha256(item_content.encode()).hexdigest()
            if v != computed_hash:
                raise ValueError(f"'item_hash' do not match 'sha256(item_content)'"
                                 f", expecting {computed_hash}")
        elif item_type == ItemType.ipfs:
            # TODO: Check that the hash looks like an IPFS multihash
            pass
        # ItemType.storage (or an 'item_type' that failed validation): no check.
        return v

    @validator("confirmed")
    def check_confirmed(
        cls, v: Optional[bool], values: Dict[str, Any]
    ) -> Optional[bool]:
        """Keep ``confirmed`` consistent with the presence of ``confirmations``.

        ``None`` (the field is Optional) is accepted as "not specified".

        Raises:
            ValueError: ``confirmed`` disagrees with ``confirmations``.
        """
        if v is None or "confirmations" not in values:
            return v
        has_confirmations = bool(values["confirmations"])
        if v and not has_confirmations:
            raise ValueError("Message cannot be 'confirmed' without 'confirmations'")
        if not v and has_confirmations:
            raise ValueError("Message with 'confirmations' must be 'confirmed'")
        return v

    class Config:
        extra = Extra.forbid


class PostMessage(BaseMessage):
    """Message of type POST, carrying a `PostContent` (unique data points, events, ...)."""

    type: Literal[MessageType.post]
    content: PostContent


class AggregateMessage(BaseMessage):
    """Message of type AGGREGATE: key-value storage specific to an address."""

    type: Literal[MessageType.aggregate]
    content: AggregateContent


class StoreMessage(BaseMessage):
    """Message of type STORE, whose content is a `StoreContent`."""

    type: Literal[MessageType.store]
    content: StoreContent


class ForgetMessage(BaseMessage):
    """Message of type FORGET, whose content is a `ForgetContent`."""

    type: Literal[MessageType.forget]
    content: ForgetContent


class ProgramMessage(BaseMessage):
    """Message of type PROGRAM, whose content is a `ProgramContent`."""

    type: Literal[MessageType.program]
    content: ProgramContent

    @validator("content")
    def check_content(cls, v: ProgramContent, values: Dict[str, Any]) -> ProgramContent:
        """Check that the parsed content matches the inline JSON it came from.

        For inline messages, ``json.loads(item_content)`` must equal
        ``v.dict(exclude_none=True)``.

        Raises:
            ValueError: the two differ; the message lists the differing keys.
        """
        if values.get("item_type") != ItemType.inline:
            return v
        raw_item_content = values.get("item_content")
        if raw_item_content is None:
            # Missing or invalid 'item_content' is handled by the BaseMessage
            # validators; nothing to compare against here.
            return v

        item_content = json.loads(raw_item_content)
        content_dict = v.dict(exclude_none=True)
        if content_dict == item_content:
            return v

        # Describe the differences, comparing the same `exclude_none` view that
        # the equality check used and covering keys present on either side.
        missing = object()

        def show(value: Any) -> str:
            return "<missing>" if value is missing else str(value)

        differences = []
        if isinstance(item_content, dict):
            for key in sorted(set(content_dict) | set(item_content), key=str):
                ours = content_dict.get(key, missing)
                theirs = item_content.get(key, missing)
                if ours != theirs:
                    differences.append(f"{key}: {show(ours)} != {show(theirs)}")
        else:
            differences.append(
                f"item_content is a {type(item_content).__name__}, not an object"
            )
        raise ValueError("Content and item_content differ: " + "; ".join(differences))


def Message(
    **message_dict: Any,
) -> Union[PostMessage, AggregateMessage, StoreMessage, ProgramMessage, ForgetMessage]:
    """Build the message model matching the ``type`` field of ``message_dict``.

    Capitalized because it is used like a constructor: ``Message(**raw)``.
    ``type`` may be the raw string (e.g. ``"POST"``) or a ``MessageType``.

    Raises:
        ValueError: ``type`` is missing or is not a known message type. Errors
            raised while validating the selected model propagate unchanged.
    """
    message_classes = {
        MessageType.post: PostMessage,
        MessageType.aggregate: AggregateMessage,
        MessageType.store: StoreMessage,
        MessageType.program: ProgramMessage,
        MessageType.forget: ForgetMessage,
    }
    raw_type = message_dict.get("type")
    # Compare with `==` rather than a dict lookup: Enum members hash by name,
    # so `message_classes.get("POST")` would not find `MessageType.post`.
    for message_type, message_class in message_classes.items():
        if raw_type == message_type:
            return message_class(**message_dict)
    raise ValueError(f"Unknown message type: {raw_type!r}")


class MessagesResponse(BaseModel):
    """Response from an Aleph node API listing messages, with pagination info."""

    # Each message is matched against the Union members in order; their
    # `Literal` `type` fields select the right model. ForgetMessage must be
    # listed, otherwise any response containing a FORGET message is rejected.
    messages: List[
        Union[PostMessage, AggregateMessage, StoreMessage, ProgramMessage, ForgetMessage]
    ]
    pagination_page: int
    pagination_total: int
    pagination_per_page: int
    pagination_item: str

    class Config:
        extra = Extra.forbid
