#!/usr/bin/env python3

import argparse
import logging
import os
import os.path
import rx
import rx.operators
import sys
import threading
import traceback
import typing

from decimal import Decimal

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import mango  # nopep8

# Command-line interface: the shared mango options (context and wallet) are
# registered first, followed by the liquidator-specific options below.
parser = argparse.ArgumentParser(
    description="Run a liquidator for a Mango Markets group."
)
mango.ContextBuilder.add_command_line_parameters(parser)
mango.Wallet.add_command_line_parameters(parser)

# Polling intervals (in seconds) later handed to start_subscriptions().
parser.add_argument(
    "--throttle-reload-to-seconds",
    default=Decimal(60),
    type=Decimal,
    help="minimum number of seconds between each full margin account reload loop (including time taken processing accounts)",
)
parser.add_argument(
    "--throttle-ripe-update-to-seconds",
    default=Decimal(5),
    type=Decimal,
    help="minimum number of seconds between each ripe update loop (including time taken processing accounts)",
)

# Wallet balancing targets; repeatable, and left as None when never supplied.
parser.add_argument(
    "--target",
    action="append",
    type=mango.parse_target_balance,
    help="token symbol plus target value or percentage, separated by a colon (e.g. 'ETH:2.5')",
)

# Decimal-valued thresholds and factors.
parser.add_argument(
    "--action-threshold",
    default=Decimal("0.01"),
    type=Decimal,
    help="fraction of total wallet value a trade must be above to be carried out",
)
parser.add_argument(
    "--worthwhile-threshold",
    default=Decimal("0.01"),
    type=Decimal,
    help="value a liquidation must be above to be carried out",
)
parser.add_argument(
    "--adjustment-factor",
    default=Decimal("0.05"),
    type=Decimal,
    help="factor by which to adjust the SELL price (akin to maximum slippage)",
)

# The four notification options differ only in flag name and help text, so
# they are registered from a table (in the same order as before).
for notify_flag, notify_help in (
    ("--notify-liquidations", "The notification target for liquidation events"),
    (
        "--notify-successful-liquidations",
        "The notification target for successful liquidation events",
    ),
    (
        "--notify-failed-liquidations",
        "The notification target for failed liquidation events",
    ),
    ("--notify-errors", "The notification target for error events"),
):
    parser.add_argument(
        notify_flag,
        type=mango.parse_notification_target,
        action="append",
        default=[],
        help=notify_help,
    )

# `store_true` already defaults to False when the flag is absent.
parser.add_argument(
    "--dry-run",
    action="store_true",
    help="runs as read-only and does not perform any transactions",
)
args: argparse.Namespace = mango.parse_args(parser)

# Attach a mango.NotificationHandler, limited to ERROR and above, to the root
# logger. It wraps the targets collected from `--notify-errors`.
error_notification_target = mango.CompoundNotificationTarget(args.notify_errors)
handler = mango.NotificationHandler(error_notification_target)
handler.setLevel(logging.ERROR)
root_logger = logging.getLogger()
root_logger.addHandler(handler)


def start_subscriptions(
    context: mango.Context,
    liquidation_processor: mango.LiquidationProcessor,
    fetch_prices: typing.Callable[[typing.Any], typing.Any],
    fetch_accounts: typing.Callable[[typing.Any], typing.Any],
    throttle_reload_to_seconds: Decimal,
    throttle_ripe_update_to_seconds: Decimal,
) -> typing.Tuple[rx.core.typing.Disposable, rx.core.typing.Disposable]:
    """Mark the processor as STARTING and subscribe it to two interval pipelines.

    Args:
        context: passed to both fetcher factories and used to create a
            thread-pool scheduler for each pipeline.
        liquidation_processor: receives account updates through
            `update_accounts` and price updates through `update_prices`.
        fetch_prices: called with `context`; its result is used as the
            `map` function of the price pipeline.
        fetch_accounts: called with `context`; its result is used as the
            `map` function of the account pipeline.
        throttle_reload_to_seconds: interval period of the account pipeline.
        throttle_ripe_update_to_seconds: interval period of the price pipeline.

    Returns:
        The `(account_subscription, price_subscription)` pair, in that order.
    """
    liquidation_processor.state = mango.LiquidationProcessorState.STARTING

    logging.info("Starting margin account fetcher subscription")
    account_ticks = rx.interval(float(throttle_reload_to_seconds))
    # Unlike the price pipeline, this one is seeded with an extra -1 value
    # (`start_with`) in addition to the interval ticks.
    account_pipeline = account_ticks.pipe(
        rx.operators.observe_on(context.create_thread_pool_scheduler()),
        rx.operators.start_with(-1),
        rx.operators.map(fetch_accounts(context)),
        rx.operators.catch(mango.observable_pipeline_error_reporter),
        rx.operators.retry(),
    )
    account_observer = mango.create_backpressure_skipping_observer(
        on_next=liquidation_processor.update_accounts,
        on_error=mango.log_subscription_error,
    )
    account_subscription = account_pipeline.subscribe(account_observer)

    logging.info("Starting price fetcher subscription")
    price_ticks = rx.interval(float(throttle_ripe_update_to_seconds))
    price_pipeline = price_ticks.pipe(
        rx.operators.observe_on(context.create_thread_pool_scheduler()),
        rx.operators.map(fetch_prices(context)),
        rx.operators.catch(mango.observable_pipeline_error_reporter),
        rx.operators.retry(),
    )

    def _deliver_prices(piped: typing.Any) -> typing.Any:
        # The price mapper yields a pair; hand its two items over separately.
        return liquidation_processor.update_prices(piped[0], piped[1])

    price_observer = mango.create_backpressure_skipping_observer(
        on_next=_deliver_prices,
        on_error=mango.log_subscription_error,
    )
    price_subscription = price_pipeline.subscribe(price_observer)

    return account_subscription, price_subscription


try:
    with mango.ContextBuilder.from_command_line_parameters(args) as context:
        wallet = mango.Wallet.from_command_line_parameters_or_raise(args)
        group = mango.Group.load(context, context.group_address)
        account = mango.Account.load_for_owner_by_address(
            context, wallet.address, group, args.account_address
        )

        # Read the tunables off `args` once; the code below (including
        # `on_unhealthy`) uses these names rather than going back to `args`.
        action_threshold = args.action_threshold
        worthwhile_threshold = args.worthwhile_threshold
        adjustment_factor = args.adjustment_factor
        throttle_reload_to_seconds = args.throttle_reload_to_seconds
        throttle_ripe_update_to_seconds = args.throttle_ripe_update_to_seconds
        liquidator_name = args.name

        logging.info(f"Wallet address: {wallet.address}")

        # NOTE(review): the group was already loaded above with an explicit
        # address; this second load looks redundant - confirm before removing.
        group = mango.Group.load(context)

        # Refuse to start unless the scout reports the wallet as prepared for
        # this group.
        logging.info("Checking wallet accounts.")
        scout = mango.AccountScout()
        report = scout.verify_account_prepared_for_group(context, group, wallet.address)
        logging.info(f"Wallet account report: {report}")
        if report.has_errors:
            raise Exception(
                f"Account '{wallet.address}' is not prepared for group '{group.address}'."
            )

        logging.info("Wallet accounts OK.")

        # Every liquidation event goes to the `--notify-liquidations` targets;
        # the two filtered targets below additionally split events by outcome.
        liquidations_publisher = mango.EventSource[mango.LiquidationEvent]()
        liquidations_publisher.subscribe(
            on_next=mango.CompoundNotificationTarget(args.notify_liquidations).send
        )  # type: ignore[call-arg]

        on_success = mango.FilteringNotificationTarget(
            mango.CompoundNotificationTarget(args.notify_successful_liquidations),
            lambda item: isinstance(item, mango.LiquidationEvent) and item.succeeded,
        )
        liquidations_publisher.subscribe(on_next=on_success.send)  # type: ignore[call-arg]

        on_failed = mango.FilteringNotificationTarget(
            mango.CompoundNotificationTarget(args.notify_failed_liquidations),
            lambda item: isinstance(item, mango.LiquidationEvent)
            and not item.succeeded,
        )
        liquidations_publisher.subscribe(on_next=on_failed.send)  # type: ignore[call-arg]

        # TODO: Add proper liquidator classes here when they're written for V3
        # Both branches are intentionally identical until then.
        if args.dry_run:
            account_liquidator: mango.AccountLiquidator = mango.NullAccountLiquidator()
        else:
            account_liquidator = mango.NullAccountLiquidator()

        # Without targets (or in a dry run) there is nothing to balance.
        # `--target` has no default, so it is None when never supplied.
        if args.dry_run or not args.target:
            wallet_balancer: mango.WalletBalancer = mango.NullWalletBalancer()
        else:
            wallet_balancer = mango.LiveWalletBalancer(
                context,
                wallet,
                account,
                group.shared_quote_token,
                args.target,
                action_threshold,
                adjustment_factor,
                args.dry_run,
            )

        # These (along with `context`) are captured and read by `load_updated_price_details()`.
        group_address = group.address
        oracle_addresses = group.oracles

        def load_updated_price_details() -> typing.Tuple[
            mango.Group, typing.Sequence[mango.InstrumentValue]
        ]:
            """Reload the group and return it with a (currently empty) price list.

            The group address and every non-None oracle address are fetched
            with one `AccountInfo.load_multiple` call. Only the first result
            (the group account) is parsed; the oracle results are not used yet.
            """
            addresses_to_load = [group_address]
            addresses_to_load.extend(
                address for address in oracle_addresses if address is not None
            )
            loaded_infos = mango.AccountInfo.load_multiple(context, addresses_to_load)

            # The group address was placed first, so its account info is first.
            refreshed_group = mango.Group.parse_with_context(context, loaded_infos[0])

            # TODO - fetch prices when code available in V3.
            return refreshed_group, []

        def fetch_prices(
            context: mango.Context,
        ) -> typing.Callable[[typing.Any], typing.Any]:
            """Return the mapper the price pipeline applies to every tick.

            The returned callable ignores its argument and runs
            `load_updated_price_details()` through `mango.retry_context`,
            passing `context.retry_pauses`.
            """

            def _reload(_: typing.Any) -> typing.Any:
                return load_updated_price_details()

            def _fetch_prices(_: typing.Any) -> typing.Any:
                retrying = mango.retry_context(
                    "Price Fetch", _reload, context.retry_pauses
                )
                with retrying as retrier:
                    return retrier.run()

            return _fetch_prices

        def fetch_accounts(
            context: mango.Context,
        ) -> typing.Callable[[typing.Any], typing.Any]:
            """Return the mapper the account pipeline applies to every tick.

            The returned callable ignores its argument and runs the account
            fetch through `mango.retry_context`, passing `context.retry_pauses`.
            The fetch itself currently returns an empty list.
            """

            def _load_accounts(_: typing.Any) -> typing.Sequence[mango.Account]:
                # Disabled for now; the earlier implementation is kept for reference:
                #   group = mango.Group.load(context)
                #   return mango.Account.load_ripe(context, group)
                return []

            def _fetch_accounts(_: typing.Any) -> typing.Any:
                retrying = mango.retry_context(
                    "Margin Account Fetch", _load_accounts, context.retry_pauses
                )
                with retrying as retrier:
                    return retrier.run()

            return _fetch_accounts

        class LiquidationProcessorSubscriptions:
            """Mutable holder for the current account and price subscriptions.

            `on_unhealthy()` disposes the held subscriptions and assigns
            replacements, so both attributes are expected to be reassigned.
            """

            # Subscription feeding account updates to the processor.
            account: rx.core.typing.Disposable
            # Subscription feeding price updates to the processor.
            price: rx.core.typing.Disposable

            def __init__(
                self,
                account: rx.core.typing.Disposable,
                price: rx.core.typing.Disposable,
            ) -> None:
                self.account = account
                self.price = price

        # Create the processor first: start_subscriptions() needs it as the
        # destination for account and price updates.
        liquidation_processor = mango.LiquidationProcessor(
            context,
            liquidator_name,
            account_liquidator,
            wallet_balancer,
            worthwhile_threshold,
        )

        # Start the initial pair of subscriptions and keep them in a mutable
        # holder so they can be swapped out later.
        account_subscription, price_subscription = start_subscriptions(
            context,
            liquidation_processor,
            fetch_prices,
            fetch_accounts,
            throttle_reload_to_seconds,
            throttle_ripe_update_to_seconds,
        )
        subscriptions = LiquidationProcessorSubscriptions(
            price=price_subscription,
            account=account_subscription,
        )

        def on_unhealthy(liquidation_processor: mango.LiquidationProcessor) -> None:
            """Rebuild both subscriptions when the processor becomes UNHEALTHY.

            Called for every processor state change. Any state other than
            UNHEALTHY is logged and ignored. Otherwise the current account and
            price subscriptions are disposed (failures are logged and
            ignored), new ones are created with `start_subscriptions()`, and
            the shared `subscriptions` holder is updated to point at them.

            Args:
                liquidation_processor: the processor whose state changed.
            """
            if liquidation_processor.state != mango.LiquidationProcessorState.UNHEALTHY:
                logging.info(
                    f"Ignoring LiquidationProcessor state change - state is: {liquidation_processor.state}"
                )
                return

            logging.warning(
                "Liquidation processor has been marked as unhealthy so recreating subscriptions."
            )

            # Dispose each old subscription independently so that a failure on
            # one neither skips the other nor prevents the replacements being
            # created. Each label names the subscription that actually failed.
            for label, stale_subscription in (
                ("margin account", subscriptions.account),
                ("price", subscriptions.price),
            ):
                try:
                    stale_subscription.dispose()
                except Exception as exception:
                    logging.warning(
                        f"Ignoring problem disposing of {label} subscription: {exception}"
                    )

            account_subscription, price_subscription = start_subscriptions(
                context,
                liquidation_processor,
                fetch_prices,
                fetch_accounts,
                throttle_reload_to_seconds,
                throttle_ripe_update_to_seconds,
            )
            subscriptions.account = account_subscription
            subscriptions.price = price_subscription

        # Let on_unhealthy() see every processor state change; it only acts
        # when the new state is UNHEALTHY.
        liquidation_processor.state_change.subscribe(on_next=on_unhealthy)  # type: ignore[call-arg]

        # Wait - don't exit. Nothing ever sets this event, so the main thread
        # blocks here until a signal/interrupt raises in it.
        threading.Event().wait()
except KeyboardInterrupt:
    logging.info("Liquidator stopping...")
except Exception as exception:
    logging.critical(
        f"Liquidator stopped because of exception: {exception} - {traceback.format_exc()}"
    )
except:
    logging.critical(
        f"Liquidator stopped because of uncatchable error: {traceback.format_exc()}"
    )
finally:
    logging.info("Liquidator completed.")
