import logging
from datetime import datetime
from pathlib import Path
from pprint import PrettyPrinter
from typing import Any, Dict
from urllib.parse import urljoin

import black
import requests

from evmchains.types import Chain

# URI schemes kept in each chain's ``rpc`` list; any other scheme (e.g. websocket)
# is filtered out by fetch_chain.
INCLUDE_PROTOCOLS = ["http", "https"]
# Base URL of the per-chain JSON files (``eip155-<chain_id>.json``) in the
# ethereum-lists/chains repository. The trailing slash matters: urljoin() would
# otherwise replace the final path segment instead of appending the file name.
SOURCE_URL = (
    "https://raw.githubusercontent.com/ethereum-lists/chains/master/_data/chains/"
)
# Destination of the generated module: evmchains/chains.py under the parent of
# the directory containing this script.
CHAIN_CONST_FILE = Path(__file__).parent.parent / "evmchains" / "chains.py"
# RPC URIs containing any of these substrings are dropped from the output.
BLACKLIST_STRINGS = [
    # 2024-01-19: Node appears to be broken. Returning errors on simple requests.
    "rpc.blocknative.com"
]

# Mapping of Ape ecosystem name -> network name -> EIP-155 chain ID.
# Only the chains listed here are fetched, and this two-level nesting is
# mirrored in the generated PUBLIC_CHAIN_META constant.
CHAIN_IDS = {
    "arbitrum": {
        "mainnet": 42161,
        "goerli": 421613,
        "sepolia": 421614,
    },
    "avalanche": {
        "mainnet": 43114,
        "fuji": 43113,
    },
    "base": {
        "mainnet": 8453,
        "sepolia": 84532,
    },
    "blast": {
        "mainnet": 81457,
        "sepolia": 168587773,
    },
    "bsc": {
        "mainnet": 56,
        "testnet": 97,
    },
    "ethereum": {
        "mainnet": 1,
        "goerli": 5,
        "sepolia": 11155111,
    },
    "fantom": {
        "mainnet": 250,
        "testnet": 4002,
    },
    "gnosis": {
        "mainnet": 100,
    },
    "optimism": {
        "mainnet": 10,
        "goerli": 420,
        "sepolia": 11155420,
    },
    "oort": {
        "mainnet": 970,
        "dev": 9700,
    },
    "polygon": {
        "mainnet": 137,
        "mumbai": 80001,
    },
    "polygon-zkevm": {
        "mainnet": 1101,
        "testnet": 1442,
    },
    "linea": {
        "mainnet": 59144,
        "sepolia": 59141,
    },
}

# Renders each chain's model_dump() as a Python literal before black reformats it.
pp = PrettyPrinter(indent=4)
# Script logger. No handler is attached here, so where (and whether) its records
# are emitted depends on logging configuration done elsewhere.
logger = logging.getLogger("update")
logger.setLevel(logging.DEBUG)


def stamp() -> str:
    """Return the current UTC time as a string for the generated file's header.

    The value is timezone-aware, so it renders with an explicit ``+00:00``
    offset, e.g. ``2024-01-19 12:34:56.789012+00:00``.
    """
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12 and yields a naive value
    # that does not say it is UTC; an aware datetime is unambiguous.
    return str(datetime.now(timezone.utc))


def ensure_dict(d: Dict[str, Any], key: str):
    """Make sure ``d[key]`` holds a dict, mutating ``d`` in place.

    An existing dict value is left untouched. A missing key, or a key whose
    value is anything other than a dict, is (re)set to a fresh empty dict.
    """
    current = d.get(key)
    if not isinstance(current, dict):
        d[key] = {}


def is_uri_blacklisted(uri: str) -> bool:
    """Return True if ``uri`` contains any substring listed in BLACKLIST_STRINGS."""
    return any(bad in uri for bad in BLACKLIST_STRINGS)


def fetch_chain(chain_id: int, timeout: float = 30.0) -> Chain:
    """Fetch one chain's metadata from the ethereum-lists repo.

    Downloads ``eip155-<chain_id>.json`` from ``SOURCE_URL`` and validates it into
    a ``Chain``. It then prunes ``chain.rpc`` to URIs that are not blacklisted
    (see ``is_uri_blacklisted``) and whose scheme is in ``INCLUDE_PROTOCOLS``.

    Args:
        chain_id: EIP-155 chain ID to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The validated ``Chain`` with its ``rpc`` list filtered (original order kept).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    url = urljoin(SOURCE_URL, f"eip155-{chain_id}.json")

    logger.info(f"GET {url}")
    # Without an explicit timeout, requests waits indefinitely on a stalled
    # connection, which would hang the whole update run.
    r = requests.get(url, timeout=timeout)
    r.raise_for_status()

    chain = Chain.model_validate_json(r.text)

    # Keep only RPCs that are not blacklisted and that use an allowed protocol
    # (this drops e.g. websocket endpoints).
    chain.rpc = [
        rpc
        for rpc in chain.rpc
        if not is_uri_blacklisted(rpc) and rpc.split(":")[0] in INCLUDE_PROTOCOLS
    ]

    return chain


def fetch_chains() -> Dict[str, Dict[str, Chain]]:
    """Fetch every chain listed in CHAIN_IDS.

    Returns a nested mapping ``{ecosystem: {network: Chain}}`` that mirrors the
    layout of CHAIN_IDS. Chains are fetched one at a time, in CHAIN_IDS order.
    """
    result: Dict[str, Dict[str, Chain]] = {}
    for eco_name, networks in CHAIN_IDS.items():
        for net_name, cid in networks.items():
            logger.info(f"Fetching chain {eco_name}:{net_name} ({cid})")
            ensure_dict(result, eco_name)
            result[eco_name][net_name] = fetch_chain(cid)
    return result


def write_chain_const(chains: Dict[str, Dict[str, Chain]]):
    """Render ``chains`` as a Python module and write it to ``CHAIN_CONST_FILE``.

    The generated module defines ``PUBLIC_CHAIN_META``, a nested dict keyed by
    ecosystem then network. Each value is the chain's ``model_dump()``, rendered
    with pprint and then normalized with black.

    Args:
        chains: Mapping of ``{ecosystem: {network: Chain}}`` to serialize.
    """
    lines = [
        "# This file is auto-generated by scripts/update.py",
        f"# {stamp()}",
        "# Do not edit this file directly.",
        "from typing import Any, Dict",
        "",
        "PUBLIC_CHAIN_META: Dict[str, Dict[str, Dict[str, Any]]] = {",
    ]
    for ecosystem, networks in chains.items():
        lines.append(f'    "{ecosystem}": {{')
        for network, chain in networks.items():
            # pprint to make a somewhat pythonic string output
            lines.append(f'        "{network}": {pp.pformat(chain.model_dump())},')
        lines.append("    },")
    lines.append("}")
    file_str = "\n".join(lines) + "\n"

    # black to make it actually readable
    try:
        file_str = black.format_file_contents(
            file_str, fast=False, mode=black.FileMode()
        )
    except black.NothingChanged:
        # black signals "input already formatted" by raising instead of
        # returning the input, so keep file_str as-is.
        pass

    # Be explicit about the encoding: chain metadata may contain non-ASCII text,
    # and the platform default encoding is not guaranteed to be UTF-8.
    with CHAIN_CONST_FILE.open("w", encoding="utf-8") as const_file:
        const_file.write(file_str)


def main():
    """Fetch every chain in CHAIN_IDS and regenerate CHAIN_CONST_FILE."""
    # No handler is attached to the "update" logger in this script. Without a
    # root handler, its INFO records would fall through to logging's last-resort
    # handler, which only emits WARNING and above, and all progress output would
    # be silently dropped. (basicConfig is a no-op if logging is already set up.)
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    logger.info("Fetching chain data...")
    logger.info(f"    Source: {SOURCE_URL}")
    logger.info(f"    Dest: {CHAIN_CONST_FILE}")
    chains = fetch_chains()
    write_chain_const(chains)


if __name__ == "__main__":
    # Only fetch and regenerate when executed as a script, not on import.
    main()
