#!/usr/bin/env python3

"""This utility takes a knowledge core and loads it into a running TrustGraph
through the API.  The knowledge core should be in msgpack format, which is the
default format produced by tg-save-kg-core.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal

class Running:
    """Mutable run/stop flag shared by the tasks and the signal handler."""

    def __init__(self):
        # The flag starts out set; only stop() ever clears it.
        self.running = True

    def get(self):
        """Return True until stop() has been called."""
        return self.running

    def stop(self):
        """Clear the flag so that loops polling get() finish."""
        self.running = False

# Running message totals: bumped by the triples / graph-embeddings sender
# tasks and printed periodically by stats().
t_counts = 0
ge_counts = 0

async def load_ge(running, queue, url):
    """Forward graph-embedding messages from a queue to the API websocket.

    Opens a websocket to ``{url}load/graph-embeddings`` and relays queued
    messages until ``running.get()`` turns false or a ``None`` end-of-load
    marker is dequeued.  Each message is expanded from the compact
    knowledge-core keys to the long-form keys before being sent as JSON.

    Args:
        running: Object whose ``get()`` returns False once shutdown has been
            requested.
        queue: ``asyncio.Queue`` delivering message dicts with keys ``"m"``
            (itself holding ``"i"``, ``"m"``, ``"u"``, ``"c"``), ``"v"`` and
            ``"e"``, or ``None`` to signal the end of the load.
        url: API base URL; the endpoint path is appended to it verbatim.
    """

    global ge_counts

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}load/graph-embeddings") as ws:

            while running.get():

                # Wait at most a second so a shutdown request is noticed
                # promptly.  asyncio.TimeoutError matches on every Python
                # version (it is an alias of TimeoutError from 3.11), and
                # catching only that lets cancellation propagate.
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                except asyncio.TimeoutError:
                    continue

                # End of load
                if msg is None:
                    break

                meta = msg["m"]

                payload = {
                    "metadata": {
                        "id": meta["i"],
                        "metadata": meta["m"],
                        "user": meta["u"],
                        "collection": meta["c"],
                    },
                    "vectors": msg["v"],
                    "entity": msg["e"],
                }

                # Best effort: report a failed send and carry on, but only
                # count the messages that were actually handed to the socket.
                try:
                    await ws.send_json(payload)
                except Exception as e:
                    print(e)
                else:
                    ge_counts += 1

async def load_triples(running, queue, url):
    """Forward triples messages from a queue to the API websocket.

    Opens a websocket to ``{url}load/triples`` and relays queued messages
    until ``running.get()`` turns false or a ``None`` end-of-load marker is
    dequeued.  Each message is expanded from the compact knowledge-core keys
    to the long-form keys before being sent as JSON.

    Args:
        running: Object whose ``get()`` returns False once shutdown has been
            requested.
        queue: ``asyncio.Queue`` delivering message dicts with keys ``"m"``
            (itself holding ``"i"``, ``"m"``, ``"u"``, ``"c"``) and ``"t"``,
            or ``None`` to signal the end of the load.
        url: API base URL; the endpoint path is appended to it verbatim.
    """

    global t_counts

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}load/triples") as ws:

            while running.get():

                # Wait at most a second so a shutdown request is noticed
                # promptly.  asyncio.TimeoutError matches on every Python
                # version (it is an alias of TimeoutError from 3.11), and
                # catching only that lets cancellation propagate.
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                except asyncio.TimeoutError:
                    continue

                # End of load
                if msg is None:
                    break

                meta = msg["m"]

                payload = {
                    "metadata": {
                        "id": meta["i"],
                        "metadata": meta["m"],
                        "user": meta["u"],
                        "collection": meta["c"],
                    },
                    "triples": msg["t"],
                }

                # Best effort: report a failed send and carry on, but only
                # count the messages that were actually handed to the socket.
                try:
                    await ws.send_json(payload)
                except Exception as e:
                    print(e)
                else:
                    t_counts += 1

async def stats(running):
    """Print the running message totals every two seconds.

    Keeps reporting the graph-embedding and triples counters until
    ``running.get()`` returns false.  The flag is checked before each
    sleep, so one final report may appear after shutdown is requested.
    """

    global t_counts, ge_counts

    while True:

        if not running.get():
            return

        await asyncio.sleep(2)

        report = f"Graph embeddings: {ge_counts:10d}  Triples: {t_counts:10d}"
        print(report)

async def _put_while_running(running, queue, item, timeout=0.5):
    """Put item on queue, retrying on timeout until running is cleared.

    Returns True once the item has been queued, or False if shutdown was
    requested before that happened.
    """

    while running.get():

        try:
            await asyncio.wait_for(queue.put(item), timeout)
        except asyncio.TimeoutError:
            # Queue still full; loop round to re-check the shutdown flag.
            # asyncio.TimeoutError matches on every Python version.
            continue

        return True

    return False


async def loader(running, ge_queue, t_queue, path, format, user, collection):
    """Read a knowledge core file and feed its records to the sender queues.

    Each unpacked record is a ``(type, payload)`` pair: type ``"t"`` goes to
    ``t_queue`` and type ``"ge"`` goes to ``ge_queue``.  Once the file is
    exhausted a ``None`` end-of-load marker is put on both queues.

    Args:
        running: Object whose ``get()`` returns False once shutdown has been
            requested; reading and queueing stop as soon as it does.
        ge_queue: ``asyncio.Queue`` receiving graph-embedding payloads.
        t_queue: ``asyncio.Queue`` receiving triples payloads.
        path: Path of the knowledge core file to read.
        format: ``"msgpack"`` or ``"json"``; only msgpack is implemented.
        user: If set, replaces the user ID stored in every record.
        collection: If set, replaces the collection ID stored in every
            record.

    Raises:
        RuntimeError: If ``format`` is ``"json"``.
    """

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        while running.get():

            try:
                unpacked = unpacker.unpack()
            except msgpack.OutOfData:
                # Normal end of file
                break
            except Exception as e:
                # Keep the best-effort behaviour of stopping at unreadable
                # data, but say so instead of ending silently.
                print(f"Stopped reading {path}: {e}")
                break

            kind, payload = unpacked[0], unpacked[1]

            if kind == "t":
                target = t_queue
            elif kind == "ge":
                target = ge_queue
            else:
                # Never route a record to a queue picked for an earlier one.
                print(f"Ignoring record of unknown type {kind!r}")
                continue

            # The overrides live under the compact keys the senders read:
            # payload["m"]["u"] is the user, payload["m"]["c"] the collection.
            if user:
                payload["m"]["u"] = user

            if collection:
                payload["m"]["c"] = collection

            if not await _put_while_running(running, target, payload):
                break

    # Put 'None' on end of each queue to tell the senders to finish
    for queue in (t_queue, ge_queue):
        await _put_while_running(running, queue, None, timeout=1)

async def run(running, **args):
    """Run the file loader, both websocket senders and the stats reporter.

    Args:
        running: Shared flag object; it is cleared here once the work has
            finished or any worker has failed, so every task winds down.
        **args: Parsed command-line options.  Uses ``url``, ``input_file``,
            ``format``, ``user`` and ``collection``.

    Raises:
        Exception: The first failure raised by any of the tasks, re-raised
            after the remaining tasks have stopped.
    """

    # Bounded queues apply back-pressure to the file reader so that
    # tg-load-kg-core doesn't grow to eat all memory
    ge_q = asyncio.Queue(maxsize=10)
    t_q = asyncio.Queue(maxsize=10)

    api_url = args["url"] + "api/v1/"

    load_task = asyncio.create_task(
        loader(
            running=running,
            ge_queue=ge_q, t_queue=t_q,
            path=args["input_file"], format=args["format"],
            user=args["user"], collection=args["collection"],
        )
    )

    ge_task = asyncio.create_task(
        load_ge(running=running, queue=ge_q, url=api_url)
    )

    triples_task = asyncio.create_task(
        load_triples(running=running, queue=t_q, url=api_url)
    )

    stats_task = asyncio.create_task(stats(running))

    workers = [triples_task, ge_task, load_task]

    # Return when all workers are done or as soon as one of them fails.
    # Awaiting them one after another could hang forever: a failed sender
    # stops draining its queue, which blocks the loader, which then never
    # delivers the end-of-load marker to the other sender.
    try:
        await asyncio.wait(workers, return_when=asyncio.FIRST_EXCEPTION)
    finally:
        running.stop()

    # With the flag cleared every loop exits; collect all tasks and then
    # surface the first failure, if there was one.
    results = await asyncio.gather(
        *workers, stats_task, return_exceptions=True
    )

    for result in results:
        if isinstance(result, BaseException):
            raise result

async def main(running):
    """Parse the command line and run the load.

    Args:
        running: Shared flag object passed through to ``run()``.
    """

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        # There is no sensible default for the file to load
        required=True,
        help='Input file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))

# Shared shutdown flag: handed to main() below and cleared by the SIGINT
# handler.
running = Running()

def interrupt(sig, frame):
    """SIGINT handler: clear the shared flag and announce the interrupt.

    Both arguments are supplied by the signal module and are unused.
    """
    running.stop()
    print('Interrupt')

# Install the handler so that Ctrl-C clears the flag rather than raising
# KeyboardInterrupt.
signal.signal(signal.SIGINT, interrupt)

# Parse the command line and run the load to completion.
asyncio.run(main(running))

