#!/usr/bin/env python3

import argparse
import logging
import os
import os.path
import sys
import typing

from decimal import Decimal
from solana.publickey import PublicKey

sys.path.insert(0, os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..")))
import mango  # nopep8


def resolve_quantity(token: mango.Token, current_quantity: Decimal, target_quantity: mango.TargetBalance, price: Decimal):
    """Resolve a target balance against the value currently held of a token.

    The value currently held is computed as ``current_quantity * price`` and is
    passed, together with the token and its price, to ``target_quantity.resolve()``.

    Args:
        token: The token the target applies to.
        current_quantity: How much of the token is currently held.
        target_quantity: The target balance whose ``resolve()`` method is called.
        price: The price of one unit of the token.

    Returns:
        Whatever ``target_quantity.resolve(token, price, held_value)`` returns.
        NOTE(review): presumably a ``mango.TokenValue`` - confirm against
        ``mango.TargetBalance.resolve``.
    """
    held_value: Decimal = current_quantity * price
    return target_quantity.resolve(token, price, held_value)


# Build the command line. ContextBuilder and Wallet register their own options on the
# parser first; the options added afterwards are specific to this balancing script.
parser = argparse.ArgumentParser(
    description="Balance the value of tokens in a Mango Markets account to specific values or percentages.",
)
mango.ContextBuilder.add_command_line_parameters(parser)
mango.Wallet.add_command_line_parameters(parser)

# Required, and may be repeated: every occurrence is parsed by mango.parse_target_balance
# and appended to the args.target list.
parser.add_argument(
    "--target",
    type=mango.parse_target_balance,
    action="append",
    required=True,
    help="token symbol plus target value or percentage, separated by a colon (e.g. 'ETH:2.5')",
)
parser.add_argument(
    "--max-slippage",
    type=Decimal,
    default=Decimal("0.05"),
    help="maximum slippage allowed for the IOC order price",
)
parser.add_argument(
    "--action-threshold",
    type=Decimal,
    default=Decimal("0.01"),
    help="fraction of total wallet value a trade must be above to be carried out",
)
# Optional: args.account_address is None when the option is not supplied.
parser.add_argument(
    "--account-address",
    type=PublicKey,
    help="address of the specific account to use, if more than one available",
)
parser.add_argument(
    "--dry-run",
    action="store_true",
    default=False,
    help="runs as read-only and does not perform any transactions",
)

# Parsing goes through mango.parse_args rather than calling parser.parse_args() directly.
args: argparse.Namespace = mango.parse_args(parser)

# Load everything the balancer needs, in dependency order: context -> wallet -> group -> account.
context = mango.ContextBuilder.from_command_line_parameters(args)
wallet = mango.Wallet.from_command_line_parameters_or_raise(args)
group = mango.Group.load(context, context.group_address)
account = mango.Account.load_for_owner_by_address(
    context,
    wallet.address,
    group,
    args.account_address,
)

# One mango.TargetBalance per --target option given on the command line.
targets: typing.Sequence[mango.TargetBalance] = args.target
logging.info(f"Targets: {targets}")

action_threshold = args.action_threshold
max_slippage = args.max_slippage

# --dry-run swaps in a NullTradeExecutor; otherwise trades go through an
# ImmediateTradeExecutor built with the configured maximum slippage.
trade_executor: mango.TradeExecutor = (
    mango.NullTradeExecutor()
    if args.dry_run
    else mango.ImmediateTradeExecutor(context, wallet, account, max_slippage)
)

# Collect one price per token in the account's basket, each looked up on the
# "<token>/<shared quote token>" market, then add the quote token itself at a price of 1.
prices: typing.List[mango.TokenValue] = []
oracle_provider: mango.OracleProvider = mango.create_oracle_provider(context, "market")
for basket_token in account.basket:
    # The basket may contain None entries; those are skipped.
    if basket_token is None:
        continue

    base_token = basket_token.token_info.token
    market_symbol: str = f"{base_token.symbol}/{group.shared_quote_token.token.symbol}"

    market = context.market_lookup.find_by_symbol(market_symbol)
    if market is None:
        raise Exception(f"Could not find market {market_symbol}")

    oracle = oracle_provider.oracle_for_market(context, market)
    if oracle is None:
        raise Exception(f"Could not find oracle for market {market_symbol}")

    price = oracle.fetch_price(context)
    prices.append(mango.TokenValue(base_token, price.mid_price))

# The shared quote token is the unit everything else is priced in, so its price is 1.
prices.append(mango.TokenValue(group.shared_quote_token.token, Decimal(1)))

# Build the balancer from the loaded account and group, the chosen trade executor, the
# requested targets and the action threshold, then run it once against the gathered prices.
account_balancer = mango.LiveAccountBalancer(account, group, trade_executor, targets, action_threshold)
account_balancer.balance(context, prices)

# Only reached if balance() returned without raising.
logging.info("Balancing completed.")
