#!/usr/bin/env python3

import argparse
import logging
import os
import os.path
import sys
import time
import typing

from datetime import timedelta
from solana.publickey import PublicKey

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import mango  # nopep8

#
# Runs the 'Keeper' instructions for a Mango Group.
#
# The Keeper is similar to Serum's 'crank' - it runs some on-chain code that needs to be run
# periodically but that can't be triggered by the chain. For example it refreshes caches, and
# updates the current perp funding.
#
# It needs to send a lot of instructions, and these are automatically split into different
# transactions (based on instruction size) because of the maximum transaction size limit. It
# usually takes 4 or 5 transactions to run all the Keeper instructions.
#
# By default, a Group will have issues if a keeper does not run an update at least
# every 10 seconds.
#
# To run the keeper instructions every 5 seconds, run:
# ```
# keeper --pulse-interval 5
# ```
#
# It's also possible to run the keeper as a 'backup', so that it only runs the instructions
# if necessary - for example, if someone else regularly runs a keeper but it has failed for
# some reason.
#
# To run the keeper instructions only if they haven't been run in the last 8 seconds:
# ```
# keeper --pulse-interval 1 --skip-interval 8
# ```
#
# Alternatively, to just see what instructions would be executed, run:
# ```
# keeper --log-level DEBUG --pulse-interval -1 --dry-run
# ```
#
# Command-line definition for the keeper. The shared context and wallet parameters are
# registered by the mango helpers; the remaining arguments are specific to this script.
parser = argparse.ArgumentParser(
    description="Runs 'Keeper'/crank functionality for a Mango Group."
)
mango.ContextBuilder.add_command_line_parameters(parser)
mango.Wallet.add_command_line_parameters(parser)
# A pulse interval of exactly -1 is treated by the main loop as 'run once and exit'.
parser.add_argument(
    "--pulse-interval",
    type=float,
    default=10,
    help="pause this number of seconds between each check (use -1 to run once and exit)",
)
# No default: when omitted (None) the main loop never skips and always runs the update.
parser.add_argument(
    "--skip-interval",
    type=float,
    help="skip the update if the cache has been updated within this number of seconds",
)
parser.add_argument(
    "--dry-run",
    action="store_true",
    default=False,
    help="runs as read-only and does not perform any transactions",
)
args: argparse.Namespace = mango.parse_args(parser)


def build_update_cache_instructions(
    context: mango.Context,
    group: mango.Group,
    perp_markets: typing.Sequence[mango.PerpMarketDetails],
) -> mango.CombinableInstructions:
    """Build the cache-refresh instructions for a group.

    Three sets of instructions are appended, in this order: root bank cache,
    price cache (for `group.oracles`), and perp market cache (for the addresses
    of `perp_markets`).

    Returns the combined instructions; nothing is executed here.
    """
    # The shared quote root bank always comes first, followed by the root bank of
    # every slot that actually has a base token bank.
    root_bank_addresses = [group.shared_quote.root_bank_address]
    root_bank_addresses.extend(
        slot.base_token_bank.root_bank_address
        for slot in group.slots
        if slot.base_token_bank is not None
    )

    combined = mango.CombinableInstructions.empty()
    combined += mango.build_mango_cache_root_banks_instructions(
        context, group, root_bank_addresses
    )
    combined += mango.build_mango_cache_prices_instructions(
        context, group, group.oracles
    )

    market_addresses = [market.address for market in perp_markets]
    combined += mango.build_mango_cache_perp_markets_instructions(
        context, group, market_addresses
    )
    return combined


def build_update_funding_instructions(
    context: mango.Context,
    group: mango.Group,
    node_bank_addresses: typing.Dict[str, typing.Sequence[PublicKey]],
    perp_markets: typing.Sequence[mango.PerpMarketDetails],
) -> mango.CombinableInstructions:
    """Build the root bank update and perp funding update instructions.

    `node_bank_addresses` maps the string form of a root bank address to that
    root bank's node bank addresses; a missing key raises `KeyError`.

    Instructions are appended for each slot's base token root bank (in slot
    order), then the shared quote root bank, then each perp market's funding.
    """
    # Base token root banks first, the shared quote root bank last.
    root_bank_addresses = [
        slot.base_token_bank.root_bank_address
        for slot in group.slots
        if slot.base_token_bank is not None
    ]
    root_bank_addresses.append(group.shared_quote.root_bank_address)

    combined = mango.CombinableInstructions.empty()
    for root_bank_address in root_bank_addresses:
        node_banks = node_bank_addresses[str(root_bank_address)]
        combined += mango.build_mango_update_root_bank_instructions(
            context, group, root_bank_address, node_banks
        )

    for market in perp_markets:
        combined += mango.build_mango_update_funding_instructions(
            context, group, market
        )

    return combined


def build_consume_events_instructions(
    context: mango.Context,
    group: mango.Group,
    perp_markets: typing.Sequence[mango.PerpMarketDetails],
    event_queues: typing.Sequence[mango.PerpEventQueue],
) -> mango.CombinableInstructions:
    """Build the 'consume events' instructions for each perp market.

    `event_queues` must be positionally aligned with `perp_markets`: the queue
    at position N has to be the event queue of the perp market at position N.

    Raises:
        Exception: if a queue's address does not match its market's event queue.
        IndexError: if `event_queues` has fewer entries than `perp_markets`.
    """
    combined = mango.CombinableInstructions.empty()

    for position, perp_market in enumerate(perp_markets):
        event_queue = event_queues[position]

        # Refuse to crank a queue against the wrong market.
        if event_queue.address != perp_market.event_queue:
            raise Exception(
                f"Sequence mismatch - perp market {perp_market.address} should have data for event queue {perp_market.event_queue} but was given event queue {event_queue.address}"
            )

        accounts = event_queue.accounts_to_crank
        combined += mango.build_perp_consume_events_instructions(
            context, group, perp_market, accounts
        )

    return combined


# Main keeper loop.
#
# One-off setup loads the group, its node banks and the perp market / event queue
# addresses. Each pass of the loop then decides whether an update is needed, reloads
# the group, perp markets and event queues in a single multi-account fetch, and either
# executes or (with --dry-run) just reports the combined keeper instructions.
with mango.ContextBuilder.from_command_line_parameters(args) as context:
    health = mango.HealthCheck()
    wallet = mango.Wallet.from_command_line_parameters_or_raise(args)
    signers: mango.CombinableInstructions = mango.CombinableInstructions.from_wallet(
        wallet
    )
    group = mango.Group.load(context, context.group_address)

    # Node banks are looked up once, keyed by the string form of their root bank's
    # address, and reused on every pass.
    node_bank_addresses: typing.Dict[str, typing.Sequence[PublicKey]] = {}
    quote_rootbank = group.shared_quote.ensure_root_bank(context)
    node_bank_addresses[str(quote_rootbank.address)] = quote_rootbank.node_banks
    for slot in group.slots:
        if slot.base_token_bank is not None:
            base_rootbank = slot.base_token_bank.ensure_root_bank(context)
            node_bank_addresses[str(base_rootbank.address)] = base_rootbank.node_banks

    perp_market_addresses = []
    event_queue_addresses = []
    for slot in group.slots:
        if slot.perp_market is not None:
            perp_market_addresses.append(slot.perp_market.address)
            perp_market = mango.PerpMarketDetails.load(
                context, slot.perp_market.address, group
            )
            event_queue_addresses.append(perp_market.event_queue)

    # Layout of the per-pass multi-account fetch:
    #   [0]                        the group
    #   [1 .. count]               the perp markets
    #   [1 + count .. 2 * count]   the matching event queues, in the same order
    perp_market_count = len(perp_market_addresses)
    addresses = [
        group.address,
        *perp_market_addresses,
        *event_queue_addresses,
    ]

    logging.info(f"Keeper running on {perp_market_count} perp markets.")

    run = True
    while run:
        try:
            run_keeper_update: bool = False
            if args.skip_interval is None:
                # No skip interval provided, so we never skip - we just always run the
                # keeper update.
                run_keeper_update = True
            else:
                cache = mango.Cache.load(context, group.cache)
                last_updates = [
                    *[pc.last_update for pc in cache.price_cache if pc is not None],
                    *[
                        rbc.last_update
                        for rbc in cache.root_bank_cache
                        if rbc is not None
                    ],
                    *[
                        pmc.last_update
                        for pmc in cache.perp_market_cache
                        if pmc is not None
                    ],
                ]

                if not last_updates:
                    # max() of an empty list raises ValueError, which would be caught
                    # below on every pass and so block the update forever. A cache with
                    # no timestamps at all is treated as stale.
                    logging.info(
                        "Cache has no update timestamps - running keeper instructions"
                    )
                    run_keeper_update = True
                else:
                    last_update = max(last_updates)
                    now = mango.utc_now()
                    threshold = now - timedelta(seconds=args.skip_interval)
                    if last_update < threshold:
                        logging.info(
                            f"Last update was {last_update} - running keeper instructions"
                        )
                        run_keeper_update = True
                    else:
                        logging.info(
                            f"Last update was {last_update} - no need for us to run keeper instructions"
                        )

            # If run_keeper_update is True it's because the Cache data is stale or we just
            # want to run the keeper update every time.
            if run_keeper_update:
                account_infos = mango.AccountInfo.load_multiple(context, addresses)

                # Re-parse the group from the fresh account data so this pass (and the
                # next pass's cache lookup) uses current state.
                group = mango.Group.parse(
                    account_infos[0],
                    group.name,
                    context.instrument_lookup,
                    context.market_lookup,
                )

                perp_markets = []
                event_queues = []
                for index in range(1, 1 + perp_market_count):
                    perp_market_details = mango.PerpMarketDetails.parse(
                        account_infos[index], group
                    )
                    perp_markets.append(perp_market_details)
                    slot = group.slot_by_perp_market_address(
                        perp_market_details.address
                    )
                    # The event queue for a market sits perp_market_count entries
                    # after the market itself in account_infos.
                    event_queue = mango.PerpEventQueue.parse(
                        account_infos[index + perp_market_count],
                        slot.perp_lot_size_converter,
                    )
                    event_queues.append(event_queue)

                update_cache_instructions = build_update_cache_instructions(
                    context, group, perp_markets
                )
                update_funding_instructions = build_update_funding_instructions(
                    context, group, node_bank_addresses, perp_markets
                )

                consume_events_instructions = build_consume_events_instructions(
                    context, group, perp_markets, event_queues
                )

                combined = (
                    signers
                    + update_cache_instructions
                    + update_funding_instructions
                    + consume_events_instructions
                )

                if not args.dry_run:
                    signatures = combined.execute(context)
                    mango.output(signatures)
                else:
                    mango.output("Dry run - not running instructions:")
                    mango.output(combined.report(context.client.instruction_reporter))

            # Reached on every pass that did not raise, including passes where the
            # update was skipped.
            health.ping("keeper")
        except KeyboardInterrupt:
            logging.info("Keeper stopping...")
            run = False
        except Exception as exception:
            logging.error(
                f"Keeper caught exception, continuing after pause: {exception}"
            )

        if args.pulse_interval == -1:
            # A pulse interval of -1 means 'run a single pass and exit'.
            run = False
        elif run:
            # Only pause when another pass is coming - there is no point sleeping
            # after a stop has already been requested. The pause is also where the
            # process spends most of its time, so Ctrl-C has to be handled here too
            # for a graceful shutdown.
            try:
                time.sleep(args.pulse_interval)
            except KeyboardInterrupt:
                logging.info("Keeper stopping...")
                run = False

logging.info("Keeper completed.")
