#!/usr/bin/env python3

from quickgpt import QuickGPT
from quickgpt import __version__
from quickgpt.thread.messagetypes import *

import argparse
import json
import os
import sys

if __name__ == "__main__":
    # Directory where conversation threads are stored, one JSON file each.
    config_path = os.path.expanduser("~/.config/quickgpt/threads")

    # exist_ok=True creates the directory (and any missing parents) only when
    # needed, without the window between a separate existence check and the
    # creation in which another process could create it and trigger
    # FileExistsError.
    os.makedirs(config_path, exist_ok=True)

    # Command-line interface. Every option is optional.
    parser = argparse.ArgumentParser(
        prog="quickgpt",
        description="Interactive command line tool to access ChatGPT",
    )

    # Credentials: stored on args.api_key (None when the flag is omitted).
    parser.add_argument(
        "-k", "--api-key",
        dest="api_key",
        help="Specify an API key to use with OpenAI",
    )

    # Thread identifier: stored on args.thread (None when omitted).
    parser.add_argument(
        "-t", "--thread",
        help="Recall a previous conversation, or start a new one with the provided identifier",
    )

    # Text of the initial prompt: stored on args.prompt.
    parser.add_argument(
        "-p", "--prompt",
        default="Assist the user with their questions.",
        help="Specify the initial prompt",
    )

    # Boolean switch stored on args.list.
    parser.add_argument(
        "-l", "--list",
        action="store_true",
        help="Lists saved threads",
    )

    # Boolean switch stored on args.no_prompt (note the custom dest).
    parser.add_argument(
        "-n", "--no-initial-prompt",
        action="store_true",
        dest="no_prompt",
        help="Disables the initial prompt, and uses the User's first input as the prompt",
    )

    # Boolean switch stored on args.stdin.
    parser.add_argument(
        "-i", "--stdin",
        action="store_true",
        help="Takes a single prompt from stdin, and returns the output via stdout",
    )

    # Boolean switch stored on args.version.
    parser.add_argument(
        "-v", "--version",
        action="store_true",
        help="Returns the version of the QuickGPT library (and this command)",
    )

    args = parser.parse_args()

    # --version: print the library version and stop.
    if args.version:
        print("quickGPT %s" % __version__)
        sys.exit(0)

    # --list: print a count line followed by one bullet per entry found in
    # the thread directory, then stop.
    if args.list:
        saved_entries = os.listdir(config_path)

        print(
            "%s threads" % len(saved_entries),
            *("* %s" % entry for entry in saved_entries),
            sep="\n"
        )

        sys.exit(0)

    # Key given with --api-key, or None when the flag was omitted.
    # NOTE(review): with None, QuickGPT presumably falls back to an
    # environment variable -- confirm against the library.
    api_key = args.api_key

    # Chat instance that owns the conversation threads
    chat = QuickGPT(api_key=api_key)

    # Pick the conversation thread to work on
    if not args.thread:
        # No identifier requested: plain new thread, id left as created.
        thread = chat.new_thread()
    else:
        thread_path = os.path.join(config_path, "%s.json" % args.thread)

        if not os.path.exists(thread_path):
            # Nothing saved under this identifier yet: start fresh and name
            # the thread after it.
            thread = chat.new_thread()
            thread.id = args.thread
        else:
            # A saved file exists: parse it and hand the data to the library
            # to rebuild the thread.
            with open(thread_path, "r") as saved_file:
                thread_obj = json.load(saved_file)

            thread = chat.restore_thread(thread_obj)

    def save():
        """Write the current thread to ``<config_path>/<thread.id>.json``.

        The thread is serialized to JSON text *before* the file is opened.
        Opening in ``"w"`` mode truncates the file immediately, so if
        ``thread.serialize()`` or ``json.dumps()`` raised after the open, the
        previously saved conversation would be left empty.

        Raises:
            TypeError: if the serialized thread is not JSON-encodable.
            OSError: if the file cannot be opened or written.
        """
        payload = json.dumps(thread.serialize())

        thread_path = os.path.join(config_path, "%s.json" % thread.id)
        with open(thread_path, "w") as f:
            f.write(payload)

    # Seed the conversation.
    # An empty thread gets the system prompt (unless --no-initial-prompt was
    # given); a thread that already holds messages has them echoed instead.
    if len(thread) == 0:
        if not args.no_prompt:
            thread.feed(System(args.prompt))
    else:
        for message in thread.thread:
            print(message)

    # --stdin: one-shot mode. Everything on stdin becomes a single user
    # message, the thread is run once, and the reply's message is printed.
    if args.stdin:
        piped_text = sys.stdin.read()
        thread.feed(User(piped_text))

        reply = thread.run()
        print(reply.message)

        sys.exit(0)

    # Imported here so this block is self-contained; only the separator
    # helper below needs it.
    import shutil

    def _print_separator():
        """Print a row of dashes as wide as the terminal."""
        # shutil.get_terminal_size() falls back to a default width when the
        # size cannot be queried (e.g. stdout is redirected to a file or a
        # pipe), whereas os.get_terminal_size() raises OSError in that case.
        print("-" * shutil.get_terminal_size().columns)

    # Interactive loop: run the thread whenever it holds messages, show the
    # response, persist the thread, then read the next user message.
    while True:
        if len(thread) > 0:
            response = thread.run()
            print(response)
            _print_separator()

            save()

        # Keep asking until a line of input is actually obtained.
        while True:
            try:
                prompt = input("<You> ")
                break
            except KeyboardInterrupt:
                # Ctrl-C at the prompt ends the session.
                print("Bye!")
                sys.exit(0)
            except EOFError:
                print("")
                # On a terminal, Ctrl-D just re-prompts (as before). When
                # stdin is a pipe or file that has run dry, every further
                # input() raises EOFError again, so re-prompting would spin
                # forever; end the session instead.
                if not sys.stdin.isatty():
                    sys.exit(0)

        _print_separator()

        thread.feed(
            User(prompt)
        )
