from __future__ import annotations

import asyncio
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from bittensor import Subtensor
from bittensor.core.metagraph import Metagraph

from bt_ddos_shield.blockchain_manager import (
    AbstractBlockchainManager,
    BittensorBlockchainManager,
    BlockchainManagerException,
)
from bt_ddos_shield.encryption_manager import AbstractEncryptionManager, ECIESEncryptionManager, EncryptionCertificate
from bt_ddos_shield.event_processor import AbstractMinerShieldEventProcessor, PrintingMinerShieldEventProcessor
from bt_ddos_shield.manifest_manager import (
    JsonManifestSerializer,
    Manifest,
    ManifestDeserializationException,
    ReadOnlyManifestManager,
)

if TYPE_CHECKING:
    import bittensor
    import bittensor_wallet
    from coincurve.keys import PrivateKey as CoincurvePrivateKey

    from bt_ddos_shield.utils import Hotkey, PublicKey


@dataclass
class ShieldMetagraphOptions:
    """Options controlling how ShieldMetagraph exposes shield addresses on axons."""

    replace_ip_address_for_axon: bool = True
    """
    Selects where the shield address is written in the axon info of the metagraph. When True, the shield address
    overwrites the original value of the `ip` field; when False, it is stored in a new `shield_address` field and
    `ip` is left untouched. The port is replaced with the shield port in both cases.
    """


class ShieldMetagraph(Metagraph):
    """
    Wrapper class for Metagraph. It allows Validator to retrieve addresses generated by MinerShield instead of Miners
    addresses presented in original metagraph. If given Miner is not shielded, it will return their original address.

    To use this class in your code just replace your Metagraph instance with ShieldMetagraph instance. If you were
    using subtensor.metagraph() before, you should now use constructor with sync param set to True and other params
    set appropriately.

    The certificate location is taken from the VALIDATOR_SHIELD_CERTIFICATE_PATH environment variable (default:
    ./validator_cert.pem). If the file does not exist, a new certificate will be generated and saved to it. If the
    file exists, the certificate will be loaded from it.
    """

    wallet: bittensor_wallet.Wallet
    """ Validator's wallet. """
    certificate: EncryptionCertificate
    """ Certificate used for encryption of addresses generated for Validator by Miners. """
    event_processor: AbstractMinerShieldEventProcessor

    encryption_manager: AbstractEncryptionManager
    blockchain_manager: AbstractBlockchainManager
    manifest_manager: ReadOnlyManifestManager
    options: ShieldMetagraphOptions

    def __init__(
        self,
        wallet: bittensor_wallet.Wallet,
        netuid: int,
        network: str | None = None,
        lite: bool = True,
        sync: bool = True,
        block: int | None = None,
        subtensor: bittensor.Subtensor | None = None,
        event_processor: AbstractMinerShieldEventProcessor | None = None,
        encryption_manager: AbstractEncryptionManager | None = None,
        blockchain_manager: AbstractBlockchainManager | None = None,
        manifest_manager: ReadOnlyManifestManager | None = None,
        options: ShieldMetagraphOptions | None = None,
    ):
        if subtensor is None:
            subtensor = Subtensor(network=network)
        super().__init__(
            netuid=netuid,
            lite=lite,
            sync=False,
            subtensor=subtensor,
            **({'network': network} if network is not None else {}),
        )

        self.wallet = wallet
        self.options = options or ShieldMetagraphOptions()
        self.event_processor = event_processor or PrintingMinerShieldEventProcessor()
        self.encryption_manager = encryption_manager or self.create_default_encryption_manager()
        self.blockchain_manager = blockchain_manager or self.create_default_blockchain_manager(
            self.subtensor, netuid, wallet, self.event_processor
        )
        self.manifest_manager = manifest_manager or self.create_default_manifest_manager(
            self.event_processor, self.encryption_manager
        )
        self._init_certificate()

        if sync:
            self.sync(block=block, lite=lite, subtensor=self.subtensor)
        elif block is not None:
            raise ValueError('Block argument is valid only when sync is True')

    def _init_certificate(self) -> None:
        certificate_path: str = os.getenv('VALIDATOR_SHIELD_CERTIFICATE_PATH', './validator_cert.pem')
        try:
            coincurve_cert: CoincurvePrivateKey = self.encryption_manager.load_certificate(certificate_path)
            self.certificate = self.encryption_manager.serialize_certificate(coincurve_cert)
            public_key: PublicKey | None = self.blockchain_manager.get_own_public_key()
            if self.certificate.public_key == public_key:
                return
        except FileNotFoundError:
            coincurve_cert = self.encryption_manager.generate_certificate()
            self.encryption_manager.save_certificate(coincurve_cert, certificate_path)
            self.certificate = self.encryption_manager.serialize_certificate(coincurve_cert)

        try:
            self.blockchain_manager.upload_public_key(self.certificate.public_key)
        except BlockchainManagerException:
            # Retry once
            time.sleep(3)
            self.blockchain_manager.upload_public_key(self.certificate.public_key)

    @classmethod
    def create_default_encryption_manager(cls):
        return ECIESEncryptionManager()

    @classmethod
    def create_default_manifest_manager(
        cls,
        event_processor: AbstractMinerShieldEventProcessor,
        encryption_manager: AbstractEncryptionManager,
    ) -> ReadOnlyManifestManager:
        return ReadOnlyManifestManager(JsonManifestSerializer(), encryption_manager, event_processor)

    @classmethod
    def create_default_blockchain_manager(
        cls,
        subtensor: bittensor.Subtensor,
        netuid: int,
        wallet: bittensor_wallet.Wallet,
        event_processor: AbstractMinerShieldEventProcessor,
    ) -> AbstractBlockchainManager:
        return BittensorBlockchainManager(
            subtensor=subtensor,
            netuid=netuid,
            wallet=wallet,
            event_processor=event_processor,
        )

    def sync(self, block: int | None = None, lite: bool = True, subtensor: bittensor.Subtensor | None = None):
        super().sync(block=block, lite=lite, subtensor=subtensor)
        hotkeys: list[str] = self.hotkeys
        urls: dict[Hotkey, str | None] = asyncio.run(self.blockchain_manager.get_manifest_urls(hotkeys))
        manifests: dict[Hotkey, Manifest | None] = asyncio.run(self.manifest_manager.get_manifests(urls))
        own_hotkey: Hotkey = self.wallet.hotkey.ss58_address
        for axon in self.axons:
            manifest: Manifest | None = manifests.get(axon.hotkey)
            if manifest is not None:
                try:
                    shield_address: tuple[str, int] | None = self.manifest_manager.get_address_for_validator(
                        manifest, own_hotkey, self.certificate.private_key
                    )
                except ManifestDeserializationException as e:
                    self.event_processor.event(
                        'Error while getting shield address for miner {hotkey}', exception=e, hotkey=axon.hotkey
                    )
                    continue
                if shield_address is not None:
                    if self.options.replace_ip_address_for_axon:
                        axon.ip = shield_address[0]
                    else:
                        axon.shield_address = shield_address[0]
                    axon.port = shield_address[1]
