#!/usr/bin/env python
# coding: utf-8


import click
import crayons
import docker
import requests
import sys


# URL of the JSON model registry fetched by get_registry().
# NOTE(review): this points at a local server — presumably a development
# setting; confirm the intended production registry URL.
REGISTRY_URI = "http://localhost:8000/registry.json"

# Docker registry hostname (presumably the default registry for model images).
# NOTE(review): not referenced in the visible code — each registry model entry
# carries its own image["registry"] value instead.
DOCKER_REGISTRY = "docker.io"


def get_registry(timeout=30):
    """
    Fetch the model registry from ``REGISTRY_URI`` and return the decoded JSON.

    :param timeout: Seconds to wait for the registry server before giving up.
    :returns: The decoded JSON document (expected to contain a ``"models"``
        list, as consumed by get_model_dict()).
    :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
    :raises requests.Timeout: if the server does not respond within ``timeout``.
    """
    # Without an explicit timeout, requests may block forever on a dead server.
    response = requests.get(REGISTRY_URI, timeout=timeout)
    # Fail loudly on error statuses instead of trying to parse an error page.
    response.raise_for_status()
    return response.json()


def get_model_dict():
    """
    Build a mapping from each registry model's shortname to a ``Model``.

    Later entries with a duplicate shortname replace earlier ones.
    """
    models = {}
    for entry in get_registry()["models"]:
        models[entry["shortname"]] = Model(entry)
    return models


def get_docker_client():
    """
    Return a new low-level ``docker.APIClient`` constructed with no arguments.
    """
    return docker.APIClient()


class Model(object):
    """
    Attribute-style view of a single model entry from the registry.

    Every key of the source dict becomes an instance attribute (this module
    reads e.g. ``shortname`` and ``image``).
    """

    def __init__(self, model_dict):
        # Shallow-copy rather than alias: the previous
        # ``self.__dict__ = model_dict`` made attribute writes on the Model
        # mutate the caller's registry dict.
        self.__dict__.update(model_dict)

    def __repr__(self):
        return "%s(shortname=%r)" % (type(self).__name__,
                                     getattr(self, "shortname", None))

    @property
    def image_uri(self):
        """Full Docker image reference, ``<registry>/<image>:<tag>``."""
        image = self.image
        return "%s/%s:%s" % (image["registry"], image["image"], image["tag"])


def run_model_command(model, command_str, pull=True,
                      stdin=None, stdout=sys.stdout, stderr=sys.stderr):
    """
    Run the given shell command inside a container instantiating the given
    model.

    :param model: Shortname of a model listed in the registry.
    :param command_str: Command to execute inside the container.
    :param pull: If true, pull the model's image before creating the container.
    :param stdin: Optional binary file-like object; its full contents are sent
        to the container's stdin, which is then closed.
    :param stdout: Stream receiving the container's stdout bytes (written via
        its ``.buffer`` when it has one, e.g. ``sys.stdout``).
    :param stderr: Stream receiving the container's stderr bytes (same rule).
    :raises click.UsageError: if ``model`` is not in the registry.
    :raises RuntimeError: if the image is not found while pulling.
    """
    import socket

    try:
        model = get_model_dict()[model]
    except KeyError:
        raise click.UsageError(f"Model {model} not found.") from None

    client = get_docker_client()
    # Use one image reference for both pulling and creating, so the container
    # runs exactly the image that was pulled (and works when pull=False).
    image_uri = model.image_uri

    if pull:
        registry = model.image["registry"]
        image, tag = model.image["image"], model.image["tag"]
        click.echo("Pulling latest Docker image for %s:%s." % (image, tag), err=True)
        try:
            client.pull(f"{registry}/{image}", tag=tag)
        except docker.errors.NotFound as exc:
            raise RuntimeError("Image not found.") from exc

    # Only keep stdin open when we will feed it; otherwise a command that reads
    # stdin would block forever in client.wait() below.
    container = client.create_container(image_uri, stdin_open=stdin is not None,
                                        command=command_str)
    try:
        client.start(container)

        if stdin is not None:
            in_stream = client.attach_socket(container, params={"stdin": 1, "stream": 1})
            try:
                # sendall(): a single send() may transmit only part of the data.
                in_stream._sock.sendall(stdin.read())
                # Half-close so the daemon sees EOF on the attached stdin.
                # NOTE(review): docker-py sets StdinOnce when stdin_open=True,
                # so the daemon should then close the container's stdin —
                # verify against the installed docker SDK version.
                in_stream._sock.shutdown(socket.SHUT_WR)
            finally:
                in_stream.close()

        # Wait for the command to finish instead of stopping it, which could
        # terminate the process before it has produced all of its output.
        client.wait(container)

        container_stdout = client.logs(container, stdout=True, stderr=False)
        container_stderr = client.logs(container, stdout=False, stderr=True)
    finally:
        # Always clean up; force=True also removes a still-running container
        # if an error interrupted us above.
        client.remove_container(container, force=True)

    # Text streams expose their byte layer as .buffer; binary streams are used as-is.
    getattr(stdout, "buffer", stdout).write(container_stdout)
    getattr(stderr, "buffer", stderr).write(container_stderr)


@click.group()
def lm_zoo():
    """
    Inspect and run language models from the central repository.

    Model commands run inside the model's Docker container.
    """


@lm_zoo.command()
@click.option("--short", is_flag=True, default=False,
              help="Output just a list of shortnames rather than a pretty list")
def list(short):
    """
    List language models available in the central repository.
    """
    # NOTE(review): this function shadows the builtin ``list`` at module scope.
    # Renaming it would need ``@lm_zoo.command("list")`` to keep the CLI name,
    # and would remove the module-level ``list`` attribute.
    show_props = [
        ("name", "Full name"),
        ("ref_url", "Reference URL"),
        ("maintainer", "Maintainer"),
        ("datetime", "Last updated"),
    ]

    for model in get_model_dict().values():
        if short:
            click.echo(model.shortname)
            continue

        click.echo(crayons.normal(model.shortname, bold=True))
        click.echo("\t{0} {1}".format(
            crayons.normal("Image URI: ", bold=True),
            model.image_uri))
        for key, label in show_props:
            # str() guards against non-string registry values (JSON null,
            # numbers), which previously raised TypeError on concatenation.
            value = str(getattr(model, key, "None"))
            click.echo("\t{0}{1}".format(
                crayons.normal(label + ": ", bold=True), value))


@lm_zoo.command()
@click.argument("model")
@click.argument("in_file", type=click.File("rb"))
def tokenize(model, in_file):
    """
    Tokenize IN_FILE with MODEL and print the result.

    The contents of IN_FILE are piped to the model container's tokenize
    command; the container's stdout and stderr are echoed locally.
    """
    run_model_command(model, "tokenize /dev/stdin",
                      stdin=in_file)


@lm_zoo.command()
@click.argument("model")
@click.argument("in_file", type=click.File("rb"))
def get_surprisals(model, in_file):
    """
    Run MODEL's get_surprisals command on IN_FILE and print the result.

    The contents of IN_FILE are piped to the model container; the
    container's stdout and stderr are echoed locally.
    """
    run_model_command(model, "get_surprisals /dev/stdin",
                      stdin=in_file)


@lm_zoo.command()
@click.argument("model")
@click.argument("in_file", type=click.File("rb"))
def unkify(model, in_file):
    """
    Run MODEL's unkify command on IN_FILE and print the result.

    The contents of IN_FILE are piped to the model container; the
    container's stdout and stderr are echoed locally.
    """
    run_model_command(model, "unkify /dev/stdin",
                      stdin=in_file)



if __name__ == "__main__":
    # Script entry point: hand control to the click command group.
    lm_zoo()
