#!/usr/bin/env python3

import argparse
import logging
import os
import os.path
import sys
import typing

from decimal import Decimal
from solana.publickey import PublicKey

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import mango  # nopep8

# Command-line interface. ContextBuilder and Wallet register their own parameters on the
# parser first; the options specific to this script follow.
parser = argparse.ArgumentParser(
    description="Balance the value of tokens in a Mango Markets group to specific values or percentages."
)
mango.ContextBuilder.add_command_line_parameters(parser)
mango.Wallet.add_command_line_parameters(parser)

# Script-specific options as (flag, add_argument keyword arguments) pairs. They are registered
# in this order, which is also the order they are added to the parser's help output.
_script_options: typing.Sequence[typing.Tuple[str, typing.Dict[str, typing.Any]]] = (
    (
        "--account-address",
        {
            "type": PublicKey,
            "help": "address of the specific account to use, if more than one available",
        },
    ),
    (
        # May be given several times; each occurrence is appended to `args.target`.
        "--target",
        {
            "type": mango.parse_fixed_target_balance,
            "action": "append",
            "required": True,
            "help": "token symbol plus target value, separated by a colon (e.g. 'ETH:2.5')",
        },
    ),
    (
        "--action-threshold",
        {
            "type": Decimal,
            "default": Decimal("0.01"),
            "help": "fraction of total wallet value a trade must be above to be carried out",
        },
    ),
    (
        "--max-slippage",
        {
            "type": Decimal,
            "default": Decimal("0.05"),
            "help": "maximum slippage allowed for the IOC order price",
        },
    ),
    (
        "--quote-symbol",
        {
            "type": str,
            "default": "USDC",
            "help": "quote token symbol to use for markets",
        },
    ),
    (
        "--dry-run",
        {
            "action": "store_true",
            "default": False,
            "help": "runs as read-only and does not perform any transactions",
        },
    ),
)
for _flag, _settings in _script_options:
    parser.add_argument(_flag, **_settings)

args: argparse.Namespace = mango.parse_args(parser)

# Main flow: build the context from the command line, load the group and account, collect a
# price for every target token (plus the quote token), then hand everything to the balancer.
with mango.ContextBuilder.from_command_line_parameters(args) as context:
    wallet: mango.Wallet = mango.Wallet.from_command_line_parameters_or_raise(args)

    # Only read options that the parser above actually defines.
    action_threshold: Decimal = args.action_threshold
    max_slippage: Decimal = args.max_slippage

    group = mango.Group.load(context, context.group_address)
    account = mango.Account.load_for_owner_by_address(
        context, wallet.address, group, args.account_address
    )

    logging.info(f"Wallet address: {wallet.address}")

    # `--target` uses action="append", so this is one entry per occurrence on the command line.
    targets: typing.Sequence[mango.FixedTargetBalance] = args.target
    logging.info(f"Targets: {targets}")

    quote_token = mango.token(context, args.quote_symbol)

    prices: typing.List[mango.InstrumentValue] = []
    oracle_provider: mango.OracleProvider = mango.create_oracle_provider(
        context, "market"
    )
    for target in targets:
        target_token = mango.token(context, target.symbol)
        # Each target token is priced via its Serum market against the quote token.
        market_symbol: str = f"serum:{target_token.symbol}/{quote_token.symbol}"
        market = mango.market(context, market_symbol)
        oracle = oracle_provider.oracle_for_market(context, market)
        if oracle is None:
            raise Exception(f"Could not find oracle for market {market_symbol}")
        price = oracle.fetch_price(context)
        prices.append(mango.InstrumentValue(target_token, price.mid_price))
    # The quote token is the unit the other prices are expressed in, so it is valued at 1.
    prices.append(mango.InstrumentValue(quote_token, Decimal(1)))

    wallet_balancer = mango.LiveWalletBalancer(
        context,
        wallet,
        account,
        quote_token,
        targets,
        action_threshold,
        max_slippage,
        args.dry_run,
    )
    wallet_balancer.balance(context, prices)

logging.info("Balancing completed.")
