#!/usr/bin/env python

"""
Trading strategy runner
"""

import multiprocessing as mp
import os
import sys
import time
import uuid
import random
from datetime import datetime
from math import ceil
from typing import List

import alpaca_trade_api as tradeapi
import pygit2
import toml
from pytz import timezone

from liualgotrader.common import config
from liualgotrader.common.market_data import get_historical_data_from_polygon
from liualgotrader.common.tlog import tlog
from liualgotrader.consumer import consumer_main
from liualgotrader.polygon_producer import polygon_producer_main
from liualgotrader.scanners_runner import main


def motd(filename: str, version: str, unique_id: str) -> None:
    """Print the start-up banner for this run.

    The banner has two sections framed by separator lines: the program
    identity (file name, version and run id), followed by the DSN and the
    ticker limit read from ``config``.
    """
    separator = "+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=+"

    # Section 1: who is running.
    print(separator)
    tlog(f"{filename} {version} starting")
    tlog(f"unique id: {unique_id}")

    # Section 2: the configuration values worth seeing at start-up.
    print(separator)
    tlog(f"DSN: {config.dsn}")
    tlog(f"MAX SYMBOLS: {config.total_tickers}")
    print(separator)


def get_trading_windows(tz, api):
    """Return today's market open and close times, expressed in ``tz``.

    Args:
        tz: tzinfo object the current time is converted to.
        api: client exposing ``get_calendar(start=..., end=...)``; the first
            returned entry must provide ``date``, ``open`` and ``close``.

    Returns:
        A ``(market_open, market_close)`` tuple of timezone-aware datetimes,
        or ``(None, None)`` when the calendar has no entry for the requested
        day or its first entry is for a later date.
    """
    tlog("checking market schedule")

    # Sample "now" once so the date string and the comparison below cannot
    # disagree if the call straddles midnight.
    today = datetime.today().astimezone(tz)
    today_str = today.strftime("%Y-%m-%d")

    sessions = api.get_calendar(start=today_str, end=today_str)
    if not sessions:
        # An empty answer means there is nothing to trade today; report it
        # the same way as "next session is not today" instead of crashing.
        tlog(f"no market session returned for {today_str}")
        return None, None
    calendar = sessions[0]

    tlog(f"next open date {calendar.date.date()}")

    if today.date() < calendar.date.date():
        tlog(f"which is not today {today}")
        return None, None

    # Project the session's wall-clock open/close times onto today's date.
    market_open = today.replace(
        hour=calendar.open.hour,
        minute=calendar.open.minute,
        second=0,
        microsecond=0,
    )
    market_close = today.replace(
        hour=calendar.close.hour,
        minute=calendar.close.minute,
        second=0,
        microsecond=0,
    )
    return market_open, market_close


"""
process main
"""

def ready_to_start(trading_api: tradeapi) -> bool:
    """Decide whether the run may begin, sleeping until market open if needed.

    Stores today's trading window in ``config.market_open`` and
    ``config.market_close`` as a side effect.

    Returns:
        True when the market schedule is bypassed, or when the market has not
        closed yet (after waiting for the open if it is still ahead).
        False when there is no session today, the market is already closed,
        or the wait is interrupted with Ctrl-C.
    """
    new_york = timezone("America/New_York")
    config.market_open, config.market_close = get_trading_windows(
        new_york, trading_api
    )

    bypass = config.bypass_market_schedule

    # No session today and no override: nothing to do.
    if not (config.market_open or bypass):
        return False

    if not bypass:
        tlog(
            f"markets open {config.market_open} market close {config.market_close}"
        )

    now = datetime.today().astimezone(new_york)
    tlog(f"current time {now}")

    if bypass:
        tlog("bypassing market schedule, are we debugging something?")
        return True

    # Market already closed for the day.
    if not (now < config.market_close):
        return False

    # Wait until just after the open when it is still in the future.
    remaining = config.market_open - now
    seconds_left = remaining.total_seconds()
    if seconds_left > 0:
        try:
            tlog(
                f"waiting for market open: {remaining} ({seconds_left} seconds)"
            )
            time.sleep(seconds_left + 1)
        except KeyboardInterrupt:
            return False

    return True


def calc_num_consumer_processes() -> int:
    """Return how many consumer processes to start.

    An explicit positive ``config.num_consumers`` wins. Otherwise the count
    is derived from the CPU count, the system load and ``config.proc_factor``
    and is never less than 1.

    Returns:
        The number of consumer processes (>= 1 on the computed path).
    """
    if config.num_consumers > 0:
        return config.num_consumers

    # os.cpu_count() returns None when the count cannot be determined; treat
    # that as "unknown" (0) so the comparison below cannot raise TypeError.
    num_cpus = os.cpu_count() or 0

    # os.getloadavg() raises OSError when the load is unobtainable and is not
    # available on every platform; fall back to "no load information".
    try:
        load_avg = sum(os.getloadavg()) / 3
    except (AttributeError, OSError):
        load_avg = 0.0

    # NOTE(review): the value logged is the mean of the 1/5/15-minute load
    # averages, although the label says "15-min"; label kept unchanged.
    tlog(f"15-min load_avg:{load_avg}, num_cpus:{num_cpus}, proc_factor:{config.proc_factor}")
    if not load_avg:
        load_avg = 1.0

    # NOTE(review): the divisor is capped at 1.0, so a load above 1 does not
    # reduce the count while a load below 1 increases it — confirm intended.
    load_divisor = min(load_avg, 1.0)
    if num_cpus > 0:
        estimate = num_cpus / load_divisor * config.proc_factor
    else:
        estimate = config.proc_factor / load_divisor

    return max(1, ceil(estimate))


"""
starting
"""


if __name__ == "__main__":
    # Script entry point: load the tradeplan, wait for the market, then start
    # the consumer, producer and scanner processes and wait for them to end.
    config.filename = os.path.basename(__file__)
    mp.set_start_method("spawn")

    # Prefer the git tag description as build label; fall back to the
    # installed package version when "../" is not a git repository.
    try:
        config.build_label = pygit2.Repository("../").describe(
            describe_strategy=pygit2.GIT_DESCRIBE_TAGS
        )
    except pygit2.GitError:
        import liualgotrader

        config.build_label = getattr(liualgotrader, "__version__", "")  # type: ignore

    uid = str(uuid.uuid4())
    motd(filename=config.filename, version=config.build_label, unique_id=uid)

    # load configuration
    # An empty folder setting resolves to the bare file name instead of
    # raising IndexError on ``[-1]``.
    folder = config.tradeplan_folder or ""
    if folder and not folder.endswith("/"):
        folder = f"{folder}/"
    fname = f"{folder}{config.configuration_filename}"
    try:
        conf_dict = toml.load(fname)
        tlog(f"loaded configuration file from {fname}")
    except FileNotFoundError:
        tlog(f"[ERROR] could not locate tradeplan file {fname}")
        # NOTE(review): exit status 0 on an error is kept as-is because
        # launchers may rely on it — consider a non-zero status.
        sys.exit(0)

    # parse configuration
    config.bypass_market_schedule = conf_dict.get("bypass_market_schedule", False)
    scanners_only = bool(conf_dict.get("test_scanners"))

    # basic validation for scanners and strategies
    tlog(f"bypass_market_schedule = {config.bypass_market_schedule}")
    # sys.exit() is used throughout: the bare exit() helper is injected by the
    # ``site`` module and is not guaranteed to exist in every interpreter.
    if not conf_dict.get("scanners"):
        tlog("must have at least one scanner configured")
        sys.exit(0)
    elif not conf_dict.get("strategies"):
        tlog("must have at least one strategy configured")
        sys.exit(0)

    scanners_conf = conf_dict["scanners"]
    for scanner_name in scanners_conf:
        tlog(f"- {scanner_name} scanner detected")

    # This client always uses the production credentials; it is passed to
    # the market-schedule check and the historical-data download below.
    data_api = tradeapi.REST(
        base_url=config.prod_base_url,
        key_id=config.prod_api_key_id,
        secret_key=config.prod_api_secret,
    )

    if ready_to_start(data_api):
        # add open positions
        symbols: List = []

        # The trading client uses production or paper credentials
        # depending on config.env.
        is_prod = config.env == "PROD"
        base_url = config.prod_base_url if is_prod else config.paper_base_url
        api_key_id = config.prod_api_key_id if is_prod else config.paper_api_key_id
        api_secret = config.prod_api_secret if is_prod else config.paper_api_secret
        trading_api = tradeapi.REST(
            base_url=base_url, key_id=api_key_id, secret_key=api_secret
        )

        if not conf_dict.get("skip_existing"):
            existing_positions = trading_api.list_positions()

            if len(existing_positions) == 0:
                tlog("no open positions")
            else:
                for position in existing_positions:
                    if position.symbol not in symbols:
                        symbols.append(position.symbol)
                        tlog(f"added existing open position in {position.symbol}")
        else:
            tlog("skipping existing open positions")

        minute_history = get_historical_data_from_polygon(
            api=data_api,
            symbols=symbols,
            max_tickers=min(config.total_tickers, len(symbols)),
        )
        # Continue only with the symbols that history was returned for.
        symbols = list(minute_history.keys())

        if not scanners_only:
            # Consumers first
            num_consumer_processes = calc_num_consumer_processes()
            tlog(f"Starting {num_consumer_processes} consumer processes")

            queues: List[mp.Queue] = [
                mp.Queue() for _ in range(num_consumer_processes)
            ]

            # Assign each symbol to a randomly chosen consumer queue:
            # q_id_hash maps symbol -> queue index, symbol_by_queue maps
            # queue index -> symbols. One generator serves all symbols.
            q_id_hash = {}
            symbol_by_queue = {}
            rng = random.SystemRandom()
            for symbol in symbols:
                _index = rng.randint(0, num_consumer_processes - 1)
                q_id_hash[symbol] = _index
                symbol_by_queue.setdefault(_index, []).append(symbol)

            consumers = [
                mp.Process(
                    target=consumer_main,
                    args=(
                        queues[i],
                        # None when no symbol landed on this queue
                        symbol_by_queue.get(i),
                        minute_history,
                        uid,
                        conf_dict,
                    ),
                )
                for i in range(num_consumer_processes)
            ]
            for p in consumers:
                p.start()

        # Queue handed to both the producer and the scanners process.
        scanner_queue: mp.Queue = mp.Queue()

        if not scanners_only:
            producer = mp.Process(
                target=polygon_producer_main,
                args=(
                    uid,
                    queues,
                    symbols,
                    q_id_hash,
                    config.market_close,
                    conf_dict,
                    scanner_queue,
                    num_consumer_processes,
                ),
            )
            producer.start()

        tlog("Starting scanners process")
        scanner = mp.Process(
            target=main,
            args=(
                conf_dict,
                config.market_open,
                config.market_close,
                scanner_queue,
            ),
        )
        scanner.start()

        # wait for completion and hope everyone plays nicely
        try:
            if not scanners_only:
                producer.join()
            scanner.join()

            if not scanners_only:
                for p in consumers:
                    p.join()

        except KeyboardInterrupt:
            # Ctrl-C: stop every child that was started.
            if not scanners_only:
                producer.terminate()
            scanner.terminate()

            if not scanners_only:
                for p in consumers:
                    p.terminate()

    print("+=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=+")
    tlog(f"run {uid} completed")
    sys.exit(0)
