#!/usr/bin/env python3
import asyncio
import math
import os
import pprint
import sys

import click
import iterm2
import yaml
from click_option_group import MutuallyExclusiveOptionGroup, optgroup
from iterm2 import Session
from iterm2.color import Color
from iterm2.profile import LocalWriteOnlyProfile


@click.command()
@optgroup.group('General options')
@optgroup.option('--clusters', '-c', 'cli_clusters', multiple=False, help='Comma-separated list of clusters specified in ~/.i2csshrc')
@optgroup.option('--machines', '-m', multiple=False, help='Comma-separated list of hosts')
@optgroup.option('--file', '-f', multiple=False, help='Cluster file (one hostname per line)')
@optgroup.option('--tab-split', '-t', is_flag=True, default=False, help="Split servers/clusters into tabs (group arguments)")
@optgroup.option('--tab-split-nogroup', '-T', is_flag=True, default=False, help="Split servers/clusters into tabs (don't group arguments)")
@optgroup.option('--same-window', '-W', is_flag=True, default=False, help="Use existing window for spawning new tabs")
@optgroup.group('SSH options')
@optgroup.option('--forward-agent', '-A', multiple=False, is_flag=True, help='Enable SSH agent forwarding')
@optgroup.option('--login', '-l', 'username', multiple=False, help='SSH user name')
@optgroup.option('--environment', '-e', multiple=False, help='Send environment vars (comma-separated list, need to start with LC_)')
@optgroup.option('--rank', '-r', is_flag=True, help='Send LC_RANK with the host number as environment variable')
@optgroup.option('--extra', '-X', multiple=True, help='Additional ssh parameters (e.g. -Xi=myidentity.pem)')
@optgroup.option('--gateway', '-g', multiple=False, help='Multihop SSH connection gateway string (e.g. username@gateway) - usually used with -A')
@optgroup.group('iTerm2 options')
@optgroup.option('--fullscreen', '-F', is_flag=True, help='Make the window fullscreen')
@optgroup.option('--broadcast', '-b', is_flag=True, default=False, help='Start with broadcast input (DANGEROUS!)')
@optgroup.option('--nobroadcast', '-nb', is_flag=True, default=False, help='Disable broadcast input')
@optgroup.option('--profile', '-p', multiple=False, help='iTerm2 profile name (default: Default)')
@optgroup.option('--sleep', '-s', multiple=False, type=int, help='Number of seconds to sleep between creating SSH sessions')
@optgroup.option('--shell', '-S', multiple=False, help='Shell to use when spawning the SSH sessions (default: bash)')
@optgroup.option('--direction', '-d', multiple=False, default="column", type=click.Choice(['column', 'row'], case_sensitive=False), help='Direction that new sessions are created (default: column)')
@optgroup.group("Layout", cls=MutuallyExclusiveOptionGroup)
@optgroup.option('--columns', '-C', multiple=False, type=int, help='Number of columns (rows will be calculated)')
@optgroup.option('--rows', '-R', multiple=False, type=int, help='Number of rows (columns will be calculated)')
@click.argument('hosts_or_cluster', nargs=-1, )
def i2cssh(hosts_or_cluster, *_args, **opts):
    """HOSTS: [(username@host [username@host] | username@cluster)]"""
    # NOTE: the docstring above is the help text click shows to the user, so
    # implementation notes are kept in comments instead.
    #
    # Overall flow: collect hosts from the positional arguments, -c, -m and
    # -f into groups (one group per tab), merge config-file and command-line
    # options into every host, attach a pane geometry to each group and hand
    # everything to iTerm2.

    # Unset boolean flags arrive as False; turn them into None so they do
    # not override values from the config file further down.
    sanitize_options(opts)

    # Keep track of valid command line options so we can filter out
    # additional options in the config file
    valid_options = list(opts.keys())

    # Load config file and get valid options. Without a config file there
    # are simply no global options to apply.
    config = read_config()
    global_options_from_config = {}
    if config:
        global_options_from_config = get_global_options_from_config(
            config, valid_options)

    # Because we can split tabs based on cluster/hosts/arguments we first
    # create different groups inside the hosts list. If it turns out we
    # don't need to split tabs, we'll just have one group by flattening
    # the list.

    # List where we eventually store the groups of hosts. Every group is a
    # list of host dictionaries holding all the options necessary to spawn
    # the SSH session.
    groups = []

    # If there's only one argument, we assume it's a cluster name (since there's
    # no need to use i2cssh for a single host) and we get hosts from the cluster
    # configuration. Otherwise we construct a list of hosts from the arguments.
    if len(hosts_or_cluster) == 1:
        # get_clusters_from_cluster_names returns a list of clusters (each a
        # list of hosts), so it has to be spliced into `groups` with extend.
        # Appending it would add one nesting level too many and the geometry
        # step below would receive a list instead of a host dictionary.
        groups.extend(get_clusters_from_cluster_names(
            [hosts_or_cluster[0]], config, valid_options))
    elif len(hosts_or_cluster) > 1:
        argument_hosts = get_hosts(hosts_or_cluster)

        # If tab_split is defined, all hosts are grouped into a single tab
        # otherwise each host is a separate tab
        if opts.get("tab_split"):
            groups.append(argument_hosts)
        else:
            groups.extend([host] for host in argument_hosts)

    # The command line option '-c' can also provide clusters, so we need to get the hosts
    # from those clusters as well. Each cluster becomes its own group.
    if opts.get("cli_clusters"):
        groups.extend(get_clusters_from_cluster_names(
            opts.get("cli_clusters").split(","), config, valid_options))

    # The command line option '-m' can also provide hosts, so we need to get the hosts
    # from those host strings as well
    if opts.get("machines"):
        groups.append(get_hosts(opts.get("machines").split(",")))

    # Hosts listed in a file (-f) form one more group
    if filename := opts.get("file"):
        groups.append(get_hosts(get_host_strs_from_file(filename)))

    # Each host might have additional options based on cluster config
    # or username@host syntax. We need to merge those options with the
    # global options specified in the config file and on the command line.
    #
    # Precedence is: command line > global config > cluster config
    for group in groups:
        # Apply global options from config file (no-op when there are none)
        apply_opts(group, global_options_from_config, valid_options)
        # Apply options from command line
        apply_opts(group, opts, valid_options)

    # TODO: all non-boolean cmdline options are passed to hosts, which isn't a problem
    # but maybe not as nice as it should be.

    # Flatten the list if we don't need to split tabs. We create a single
    # group though, so we can still use the same code to calculate geometry
    # in both cases
    if not (opts.get("tab_split") or opts.get("tab_split_nogroup")):
        groups = [flatten(groups)]

    # Drop groups without hosts (e.g. an empty host file): the geometry step
    # below reads the first host of every group.
    groups = [hosts for hosts in groups if hosts]

    if not groups:
        click.echo("No hosts found")
        sys.exit(1)

    # Attach a geometry to each group. Here we take the first host in the group
    # since it will have all the options (including geometry) from the config
    groups = [{
        "hosts": hosts,
        "geometry": compute_geometry(len(hosts), hosts[0])
    } for hosts in groups]

    iterm2.run_until_complete(exec_in_iterm(groups, opts))


def get_host_strs_from_file(filename):
    """Read host strings from a file that lists one host per line.

    Surrounding whitespace is stripped from every line. Blank lines are
    skipped so they do not turn into hosts with an empty name.

    Args:
        filename: path of the file to read.

    Returns:
        List of non-empty host strings in file order.
    """
    with open(filename) as f:
        return [stripped for line in f if (stripped := line.strip())]


def sanitize_options(opts):
    """Replace every option whose value is exactly False with None, in place.

    On the command line a boolean flag can only be switched on, so False
    means "not given". The config file can set booleans to either True or
    False. Mapping False to None lets later code tell "not given" apart from
    an explicit value.
    """
    unset_flags = [name for name, value in opts.items() if value is False]
    for name in unset_flags:
        opts[name] = None


def exec_in_iterm(groups, opts):
    """Build the coroutine that creates the iTerm2 panes and starts the SSH sessions.

    Args:
        groups: list of dictionaries, each with a "hosts" list (host option
            dictionaries) and a "geometry" dictionary providing "rows" and
            "cols". The first group gets a new window (or the current one
            when "same_window" is set); every other group gets a new tab.
        opts: command line options. The keys read here are "same_window",
            "fullscreen", "direction", "rank", "environment" and "sleep";
            the whole dictionary is also passed to create_ssh_prefix.

    Returns:
        A coroutine function that takes the iTerm2 connection. The caller
        passes it to iterm2.run_until_complete.
    """
    async def inner(connection):
        # Setup iterm API connection
        app = await iterm2.async_get_app(connection)
        window = app.current_window

        # Bail if we're not inside iterm
        if window is None:
            print("No current window")
            return

        # The ssh prefix and the user supplied environment variables only
        # depend on the command line, so they are computed once here rather
        # than once per pane.
        ssh_prefix = create_ssh_prefix(opts)
        user_env = {}
        if from_env := opts.get("environment"):
            # Split on the first "=" only, so a value may itself contain "="
            user_env = dict(s.split("=", 1) for s in from_env.split(","))

        broadcast_domains = []
        for i, group in enumerate(groups):

            # Hosts in the group use the same profile, since we can't specify a profile per host
            # in the config file. Because of this, we take the profile for the first host in the
            # group and use that for all hosts in that group.
            profile_name = group["hosts"][0].get("profile", "Default")

            # Same thing for the shell
            shell = f"/usr/bin/env {group['hosts'][0].get('shell', 'bash')} -l "
            lwop = LocalWriteOnlyProfile()
            lwop.set_use_custom_command("Yes")
            lwop.set_command(f'{shell}\n')

            # The first group opens a new window unless --same-window is
            # given; all other groups open a new tab in that window.
            if i == 0 and not opts.get("same_window"):
                window = await window.async_create(connection, profile_name, profile_customizations=lwop)
            else:
                await window.async_create_tab(profile_name, profile_customizations=lwop)

            # Set window to fullscreen if necessary
            if i == 0 and opts.get("fullscreen"):
                await window.async_set_fullscreen(True)

            cols = group["geometry"]["cols"]
            rows = group["geometry"]["rows"]

            # `vertical` is used for the splits that build the first row,
            # `horizontal` for the splits that add further rows.
            if opts.get("direction") == "column":
                vertical = True
                horizontal = False
            else:
                vertical = False
                horizontal = True

            # List to keep track of panes that are created
            panes = []
            # First pane is already there
            # NOTE(review): relies on the window/tab created above being the
            # window's current tab - confirm against the iterm2 API.
            pane = window.current_tab.current_session
            panes.append(pane)

            # Split vertically cols-1 times from the last pane
            # This takes care of the first row
            for _ in range(1, cols):
                pane = await pane.async_split_pane(vertical, profile_customizations=lwop, profile=profile_name)
                panes.append(pane)

            # For subsequent rows, we first go back to the first
            # pane of the row above, then split horizontally
            for row in range(1, rows):
                first_col_of_last_row = col_row_to_index(0, row-1, cols)
                pane: Session = panes[first_col_of_last_row]

                # Then for each column we split horizontally and jump
                # to the pane in the next column on the previous row
                for col in range(cols):
                    new_pane = await pane.async_split_pane(horizontal, profile_customizations=lwop, profile=profile_name)
                    panes.append(new_pane)
                    next_pane = col_row_to_index(col+1, row-1, cols)
                    pane = panes[next_pane]

            # Add panes for this group to a broadcast domain, if the first host specifies
            # broadcast as an option, UNLESS nobroadcast is set as well.
            if not group["hosts"][0].get("nobroadcast") and group["hosts"][0].get("broadcast"):
                broadcast_domains.append(panes)

            for p, pane in enumerate(panes):
                # There might be more panes than hosts
                if p < len(group["hosts"]):
                    env_vars = {}
                    send_env = ""

                    # Rank is just an env var, so set it if we have it
                    if opts.get("rank"):
                        env_vars["LC_RANK"] = str(p)

                    # Add the env vars given on the command line
                    env_vars.update(user_env)

                    # If we have env vars, we export them and set them in the
                    # ssh options
                    if env_vars:
                        (env_vars_str, send_env) = get_env_vars_str(env_vars)
                        await pane.async_send_text(f"{env_vars_str}\n")

                    host = get_host_str(group["hosts"][p])

                    if s := opts.get("sleep"):
                        await asyncio.sleep(s)

                    await pane.async_send_text(f"unset HISTFILE && {ssh_prefix} {send_env} {host}\n")

                # If we run out of hosts, display an "Unused" pane
                else:
                    profile = await pane.async_get_profile()
                    await pane.async_send_text("unset HISTFILE\n")
                    await asyncio.sleep(0.3)
                    await profile.async_set_foreground_color(Color.from_hex("#ff0000"))
                    crs = "\n" * 100
                    await pane.async_send_text(f"stty -isig -icanon -echo && echo -e '{crs}UNUSED' && cat > /dev/null\n")

            # Activate first pane
            await panes[0].async_activate()

        # Enable broadcast input for all groups that require it
        await enable_broadcast(connection, broadcast_domains)

    return inner


async def enable_broadcast(connection, broadcast_domains):
    """Turn every list of sessions into an iTerm2 broadcast domain and activate them.

    Args:
        connection: the iTerm2 connection passed on to the API call.
        broadcast_domains: list of lists of sessions; the sessions of one
            inner list end up in the same broadcast domain.
    """
    def build_domain(sessions):
        # One domain holding all sessions of a single group
        new_domain = iterm2.broadcast.BroadcastDomain()
        for member in sessions:
            new_domain.add_session(member)
        return new_domain

    await iterm2.async_set_broadcast_domains(
        connection, [build_domain(sessions) for sessions in broadcast_domains])


def get_host_str(host):
    """Format a host dictionary as `username@hostname`, or just `hostname` when no username is set."""
    user = host.get("username")
    name = host["hostname"]
    return f"{user}@{name}" if user else name


def get_env_vars_str(env_vars):
    """Build the shell export statements and the matching ssh SendEnv option.

    Args:
        env_vars: dictionary mapping variable names to values.

    Returns:
        Tuple `(exports, send_env)`: `exports` holds one `export NAME=VALUE;`
        statement per variable, joined by spaces; `send_env` is
        `-o SendEnv=` followed by the comma-separated variable names.
    """
    exports = []
    names = []
    for name, value in env_vars.items():
        exports.append(f"export {name}={value};")
        names.append(name)
    return (" ".join(exports), "-o SendEnv=" + ",".join(names))


def create_ssh_prefix(opts):
    """Return the ssh command prefix, including the ssh options.

    Args:
        opts: option dictionary. The keys read are "forward_agent" (adds
            `-A`), "extra" (iterable of extra ssh parameters) and "gateway"
            (adds a ProxyCommand that hops through the gateway).

    Returns:
        The string `"ssh "` followed by the space-separated options.
    """
    ssh_options = []
    if opts.get("forward_agent"):
        ssh_options.append("-A")

    for extra in opts.get("extra") or ():
        # "i=key.pem" becomes "-i key.pem", a bare "v" becomes "-v". Only the
        # first "=" separates flag and value, so values that contain "="
        # themselves (e.g. "o=StrictHostKeyChecking=no") are passed intact.
        flag, separator, value = extra.partition("=")
        ssh_options.append(f"-{flag} {value}" if separator else f"-{flag}")

    if gateway := opts.get("gateway"):
        ssh_options.append(f'-o ProxyCommand="ssh -W %h:%p {gateway}"')

    return "ssh " + " ".join(ssh_options)


def col_row_to_index(col, row, cols):
    """Calculate the index of a pane from its column and row in a grid that is `cols` wide."""
    panes_before_row = row * cols
    return panes_before_row + col


def compute_geometry(hosts, opts):
    """Work out how many rows and columns of panes are needed.

    Args:
        hosts: number of hosts that need a pane.
        opts: option dictionary; a truthy "rows" value wins over a truthy
            "columns" value. When neither is given the grid is made roughly
            square.

    Returns:
        Dictionary with "rows", "cols" and "requires_fullscreen".
    """
    requested_rows = opts.get("rows")
    if requested_rows:
        rows = requested_rows
        cols = math.ceil(hosts / rows)
    else:
        requested_cols = opts.get("columns")
        if requested_cols:
            cols = requested_cols
            rows = math.ceil(hosts / cols)
        else:
            rows = math.ceil(math.sqrt(hosts))
            cols = math.ceil(hosts / rows)

    # Quick hack: iTerms default window only supports up to 11 rows and 22 columns
    # If we surpass either one, we resort to full screen.
    too_big_for_window = rows > 11 or cols > 22
    return {"rows": rows, "cols": cols, "requires_fullscreen": too_big_for_window}


def get_hosts(host_strings):
    """Turn `hostname` / `username@hostname` strings into host dictionaries.

    Each result has the keys "hostname" and "username"; the username is
    whatever parse_hostname returns for the string.
    """
    def to_host(spec):
        user, name = parse_hostname(spec)
        return {"hostname": name, "username": user}

    return [to_host(spec) for spec in host_strings]


def get_clusters_from_cluster_names(cluster_names, config, valid_options):
    """
    Get hosts from a list of cluster names. Note that this returns
    a list of hosts grouped by cluster.

    Args:
        cluster_names: cluster names, optionally written as
            `username@cluster`; that username is set on every host of
            the cluster.
        config: parsed config file, or None when there is no config file.
            Clusters are looked up under its "clusters" key.
        valid_options: option names that may be copied from the cluster
            configuration onto each host.

    Returns:
        One list of host dictionaries per cluster name, in input order.

    Raises:
        click.ClickException: when a cluster (with a "hosts" entry) is not
            defined in the config file, so the user gets a readable message
            instead of a KeyError/TypeError traceback.
    """
    # Tolerate a missing config file or a config without a "clusters" section
    known_clusters = (config or {}).get("clusters") or {}

    clusters = []
    for cluster_name in cluster_names:
        (username, cluster_name) = parse_hostname(cluster_name)

        cluster_config = known_clusters.get(cluster_name)
        if not isinstance(cluster_config, dict) or "hosts" not in cluster_config:
            raise click.ClickException(
                f"Cluster '{cluster_name}' is not defined (with a 'hosts' list) in ~/.i2csshrc")

        # Options set on the cluster itself, limited to the valid option names
        cluster_options = get_global_options_from_config(
            cluster_config, valid_options)

        hosts = []
        for hostname in cluster_config["hosts"] or []:
            # Override host options with cluster options
            host_opts = {"hostname": hostname, "username": username}
            apply_opts(host_opts, cluster_options, valid_options)
            hosts.append(host_opts)
        clusters.append(hosts)
    return clusters


def apply_opts(opts, overrides, valid_options):
    """
    Copy override values onto a host dictionary, in place.

    Only keys listed in `valid_options` are considered, and an override is
    only applied when its value is not None. When `opts` is a list, the
    function recurses into every element, so arbitrarily nested lists of
    host dictionaries are handled.
    """
    if isinstance(opts, list):
        for nested in opts:
            apply_opts(nested, overrides, valid_options)
        return

    opts.update({
        name: overrides[name]
        for name in valid_options
        if overrides.get(name) is not None
    })


def parse_hostname(hostname):
    """Parse a hostname and return a tuple of (username, host).

    The string is split on the LAST "@", so user names that contain an "@"
    themselves (e.g. `user@domain@host`) are handled instead of raising a
    ValueError. The username is None when the string contains no "@".
    """
    username, separator, host = hostname.rpartition("@")
    if not separator:
        return None, hostname
    return username, host


def get_global_options_from_config(config, valid_options):
    """Pick the entries of `config` whose key is one of `valid_options`.

    Returns a new dictionary, ordered like `valid_options`; all other
    config keys are ignored.
    """
    return {name: config[name] for name in valid_options if name in config}


def read_config():
    """Load `~/.i2csshrc` as YAML.

    Returns the parsed content, or None when the file does not exist or
    cannot be parsed (the YAML error is printed in that case).
    """
    rc_path = os.path.expanduser("~/.i2csshrc")

    # No config file: nothing to load
    if not os.path.isfile(rc_path):
        return None

    with open(rc_path, "r") as handle:
        try:
            parsed = yaml.safe_load(handle)
        except yaml.YAMLError as err:
            print(err)
            return None
        else:
            return parsed


def flatten(l):
    """Flatten a list of lists by one level and return the result as a new list."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat


def main():
    """Console entry point: run the click command under the program name `i2cssh`."""
    # Calling a click command object is an alias for its main() method
    i2cssh.main(prog_name="i2cssh")


# Only start the CLI when this file is executed as a script, not on import
if __name__ == "__main__":
    main()
