#!/usr/bin/env python3

"""
This utility connects to a running TrustGraph through the API and creates
a knowledge core from the data streaming through the processing queues.
For completeness of data, tg-save-kg-core should be initiated before data
loading takes place.  The default output format, msgpack, should be used.
JSON output format is also available - msgpack produces a more compact
representation, which is also more performant to load.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal

class Running:
    """
    Mutable stop flag shared between the signal handler and the tasks.

    The flag starts out true and is cleared, once and for all, by stop().
    """

    def __init__(self):
        self.running = True

    def get(self):
        """Return True until stop() has been called."""
        return self.running

    def stop(self):
        """Clear the flag so that loops polling get() wind down."""
        self.running = False

async def fetch_ge(running, queue, user, collection, url):
    """
    Stream graph embeddings from the API websocket onto a queue.

    Connects to ``{url}stream/graph-embeddings`` and, for as long as
    ``running.get()`` is true, turns every TEXT message into a
    ``("ge", record)`` item on ``queue``.  The loop ends when the flag is
    cleared, when the socket reports an error, or when the socket is closed.

    Args:
        running: object whose ``get()`` turns false when shutdown is wanted.
        queue: queue which receives the items via ``await queue.put(...)``.
        user: only keep messages whose metadata user equals this value;
            a falsy value disables the filter.
        collection: only keep messages whose metadata collection equals
            this value; a falsy value disables the filter.
        url: API base URL which the stream path is appended to.
    """

    # Message types which mean no further data will arrive on the socket
    closed_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED,
    )

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:

            while running.get():

                # Short timeout so that the running flag is re-checked
                # regularly.  Only the timeout is swallowed; anything else
                # (including task cancellation) must propagate.
                try:
                    msg = await asyncio.wait_for(ws.receive(), 1)
                except asyncio.TimeoutError:
                    continue

                if msg.type == aiohttp.WSMsgType.ERROR:
                    print("Error")
                    break

                if msg.type in closed_types:
                    # Without this the loop would keep polling a dead socket
                    print("Graph embeddings stream closed")
                    break

                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                data = msg.json()
                metadata = data["metadata"]

                if user and metadata["user"] != user:
                    continue

                if collection and metadata["collection"] != collection:
                    continue

                # Tuple, to match the items produced by fetch_triples; it
                # serialises exactly like a list in both output formats
                await queue.put((
                    "ge",
                    {
                        "m": {
                            "i": metadata["id"],
                            "m": metadata["metadata"],
                            "u": metadata["user"],
                            "c": metadata["collection"],
                        },
                        "v": data["vectors"],
                        "e": data["entity"],
                    }
                ))

async def fetch_triples(running, queue, user, collection, url):
    """
    Stream triples from the API websocket onto a queue.

    Connects to ``{url}stream/triples`` and, for as long as
    ``running.get()`` is true, turns every TEXT message into a
    ``("t", record)`` item on ``queue``.  The loop ends when the flag is
    cleared, when the socket reports an error, or when the socket is closed.

    Args:
        running: object whose ``get()`` turns false when shutdown is wanted.
        queue: queue which receives the items via ``await queue.put(...)``.
        user: only keep messages whose metadata user equals this value;
            a falsy value disables the filter.
        collection: only keep messages whose metadata collection equals
            this value; a falsy value disables the filter.
        url: API base URL which the stream path is appended to.
    """

    # Message types which mean no further data will arrive on the socket
    closed_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED,
    )

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}stream/triples") as ws:

            while running.get():

                # Short timeout so that the running flag is re-checked
                # regularly.  Only the timeout is swallowed; anything else
                # (including task cancellation) must propagate.
                try:
                    msg = await asyncio.wait_for(ws.receive(), 1)
                except asyncio.TimeoutError:
                    continue

                if msg.type == aiohttp.WSMsgType.ERROR:
                    print("Error")
                    break

                if msg.type in closed_types:
                    # Without this the loop would keep polling a dead socket
                    print("Triples stream closed")
                    break

                if msg.type != aiohttp.WSMsgType.TEXT:
                    continue

                data = msg.json()
                metadata = data["metadata"]

                if user and metadata["user"] != user:
                    continue

                if collection and metadata["collection"] != collection:
                    continue

                await queue.put((
                    "t",
                    {
                        "m": {
                            "i": metadata["id"],
                            "m": metadata["metadata"],
                            "u": metadata["user"],
                            "c": metadata["collection"],
                        },
                        "t": data["triples"],
                    }
                ))

# Running totals of records written to the output file: incremented by
# output() and reported by stats()
ge_counts = 0
t_counts = 0

async def stats(running):
    """
    Report the record counters on stdout every two seconds.

    The flag is checked before each sleep, so one more report may still be
    printed after shutdown has been requested.

    Args:
        running: object whose ``get()`` turns false when shutdown is wanted.
    """

    interval = 2

    while True:

        if not running.get():
            break

        await asyncio.sleep(interval)

        # The module-level counters are only read here, never assigned
        report = f"Graph embeddings: {ge_counts:10d}  Triples: {t_counts:10d}"
        print(report)

async def output(running, queue, path, format):
    """
    Write queued records to the output file until shutdown.

    Items are taken from ``queue`` and appended to ``path``, encoded with
    msgpack when ``format`` is "msgpack" and as UTF-8 JSON otherwise.  The
    module-level counters ``t_counts`` and ``ge_counts`` are incremented for
    items tagged "t" and "ge" respectively.  Once ``running.get()`` turns
    false, items already sitting in the queue are still written before the
    file is closed, so records received before the interrupt are not lost.

    Args:
        running: object whose ``get()`` turns false when shutdown is wanted.
        queue: asyncio queue of ``(tag, record)`` items.
        path: output file path; an existing file is overwritten.
        format: "msgpack" or anything else for JSON.
    """

    global t_counts
    global ge_counts

    with open(path, "wb") as f:

        while True:

            if running.get():
                # Short timeout so that the running flag is re-checked
                # regularly.  asyncio.TimeoutError matches on every Python
                # version: from 3.11 it is an alias of builtin TimeoutError.
                try:
                    msg = await asyncio.wait_for(queue.get(), 0.5)
                except asyncio.TimeoutError:
                    continue
            else:
                # Shutting down: flush what has already been received
                try:
                    msg = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

            # NOTE(review): JSON records are written back-to-back with no
            # separator between them; confirm what the loader expects
            # before changing the on-disk format.
            if format == "msgpack":
                f.write(msgpack.packb(msg, use_bin_type=True))
            else:
                f.write(json.dumps(msg).encode("utf-8"))

            if msg[0] == "t":
                t_counts += 1
            elif msg[0] == "ge":
                ge_counts += 1

    print("Output file closed")

async def run(running, **args):
    """
    Start the fetch, output and statistics tasks and wait for them all.

    The two fetchers feed a shared queue which the output task drains to
    the output file.  All tasks are awaited together so that a failure in
    any one of them (for example a refused websocket connection) surfaces
    straight away instead of only after the output task has finished.

    Args:
        running: shared stop flag; it is also cleared here when a task
            fails so that the remaining tasks wind down.
        **args: parsed command-line options; the keys used are "url",
            "user", "collection", "output_file" and "format".
    """

    q = asyncio.Queue()

    # Tolerate a base URL given without its trailing slash
    api_url = args["url"].rstrip("/") + "/api/v1/"

    tasks = [
        asyncio.create_task(
            fetch_ge(
                running=running, queue=q,
                user=args["user"], collection=args["collection"],
                url=api_url,
            )
        ),
        asyncio.create_task(
            fetch_triples(
                running=running, queue=q,
                user=args["user"], collection=args["collection"],
                url=api_url,
            )
        ),
        asyncio.create_task(
            output(
                running=running, queue=q,
                path=args["output_file"], format=args["format"],
            )
        ),
        asyncio.create_task(stats(running)),
    ]

    try:
        await asyncio.gather(*tasks)
    finally:
        # If one task failed, tell the others to stop polling as well
        running.stop()

    print("Exiting")

async def main(running):
    """
    Parse the command line and run the capture.

    The API URL defaults to the TRUSTGRAPH_API environment variable, or to
    http://localhost:8088/ when that is not set.

    Args:
        running: shared stop flag, handed on to run().
    """

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser = argparse.ArgumentParser(
        prog='tg-save-kg-core',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-o', '--output-file',
        # Make it mandatory, difficult to over-write an existing file
        required=True,
        help='Output file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Output format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to filter on (default: no filter)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to filter on (default: no filter)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))

# Shared stop flag: cleared by the SIGINT handler below and polled by the
# fetch, output and stats loops
running = Running()

def interrupt(sig, frame):
    """Signal handler: request shutdown by clearing the running flag."""
    running.stop()
    print('Interrupt')

# Installed before the event loop starts, so that Ctrl-C requests a
# cooperative shutdown through the flag
signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))

