#!/usr/bin/env python

"""
Parallel sync is an experimental feature that leverages the available
CPUs/threads to increase throughput.
This can be useful for environments that have high latency.

In this scenario, your PG database, Elasticsearch, and PGSync servers are
on different subnets with a delay between request/response time.

The main bottleneck, in this case, is usually the roundtrip of the database
query.

Even with server-side cursors, we are still only able to fetch
a limited number of records at a time from the cursor.
The delay in the next cursor fetch can slow down the overall sync
considerably.

The solution here is to perform an initial "fast"/"parallel" sync using this
tool to populate Elasticsearch with the majority of the records.
We can then revert to the native sync process.

This approach uses the Tuple identifier record of the table columns.
Each table contains a hidden system column - "ctid" of type "tid" that
identifies the page record and row number in each block.

We can use this to paginate the sync process.
Pagination here technically implies that we are splitting each paged record
between CPUs/threads.

This allows us to perform Elasticsearch bulk inserts in parallel.
The "ctid" is a tuple of (page, row-number) e.g (1, 5) that identifies the
row in a disk page.

This method allows us to fetch all paged row records upfront and split them
into work units amongst the workers. Each chunk of work is defined by the
BLOCK_SIZE and this corresponds to the number of root node records to process
for each worker.
The workers then query for each chunk of work filtering by the page number
and row numbers assigned.
"""

import asyncio
import json
import multiprocessing
import os
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import Generator, Optional, Union

import click
import sqlalchemy as sa

from pgsync.node import Node
from pgsync.settings import BLOCK_SIZE
from pgsync.sync import Sync
from pgsync.utils import get_config, show_settings, timeit


def logical_slot_changes(
    doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """Process logical slot changes for ``doc`` and advance its checkpoint.

    A new ``Sync`` is built for ``doc``; its ``logical_slot_changes`` method
    is called with ``txmin`` set to the stored checkpoint and ``txmax`` set
    to the current transaction id. The checkpoint is then overwritten with
    that upper bound.

    Args:
        doc: One schema document from the config file.
        verbose: Forwarded to ``Sync``.
        validate: Forwarded to ``Sync``.
    """
    syncer = Sync(doc, verbose=verbose, validate=validate)
    lower_txid = syncer.checkpoint
    upper_txid = syncer.txid_current
    # Sync up to the upper bound to capture everything the bulk pass may
    # have missed.
    syncer.logical_slot_changes(txmin=lower_txid, txmax=upper_txid)
    # Fall back to a fresh transaction id if the captured one is falsy.
    syncer.checkpoint = upper_txid or syncer.txid_current


@dataclass
class Task:
    """Bundle of ``Sync`` constructor arguments plus a per-chunk entry point.

    ``multiprocess`` maps :meth:`process` over the ctid work units in a
    ``ProcessPoolExecutor``; every call builds its own ``Sync`` from the
    stored arguments.
    """

    # Schema document handed to ``Sync``.
    doc: dict
    # Forwarded to ``Sync`` unchanged.
    verbose: bool = False
    # Forwarded to ``Sync`` unchanged.
    validate: bool = False

    def process(self, task: dict) -> None:
        """Bulk-index the rows selected by one ctid work unit.

        Args:
            task: Mapping of page number to row numbers, as yielded by
                ``fetch_tasks``.
        """
        worker = Sync(self.doc, verbose=self.verbose, validate=self.validate)
        lower_txid = worker.checkpoint
        upper_txid = worker.txid_current
        worker.es.bulk(
            worker.index,
            worker.sync(ctid=task, txmin=lower_txid, txmax=upper_txid),
        )
        # Report completion right away, tagged with the process id.
        sys.stdout.write(f"Process pid: {os.getpid()} complete.\n")
        sys.stdout.flush()


@timeit
def fetch_tasks(doc: dict, block_size: Optional[int] = None) -> Generator:
    """Yield ctid work units for the root node of ``doc``.

    Selects the ``ctid`` of every row of the root node's model and groups
    the results into dicts mapping page number to a list of row numbers.
    A dict is yielded after every ``block_size`` rows.

    NOTE: the final ``yield`` is unconditional, so the last dict is empty
    when the row count is an exact multiple of ``block_size`` (including
    zero rows).

    Args:
        doc: One schema document from the config file.
        block_size: Rows per work unit; falls back to ``BLOCK_SIZE`` when
            ``None`` or ``0``.

    Yields:
        ``{page: [row, ...]}`` dicts.
    """
    block_size = block_size or BLOCK_SIZE
    sync = Sync(doc)
    root = sync.tree.build(sync.nodes)
    columns = [
        sa.literal_column("1").label("x"),
        sa.literal_column("1").label("y"),
        sa.column("ctid"),
    ]
    statement = sa.select(columns).select_from(root.model)

    chunk = {}
    for count, (_, _, ctid) in enumerate(sync.fetchmany(statement), start=1):
        # assumes ctid[0] is a string shaped like "(page,row)" — confirm
        # against what Sync.fetchmany returns for the ctid column.
        parts = ctid[0].split(",")
        page = int(parts[0].replace("(", ""))
        row = int(parts[1].replace(")", ""))
        chunk.setdefault(page, []).append(row)
        if count % block_size == 0:
            yield chunk
            chunk = {}
    yield chunk


@timeit
def synchronous(
    tasks: Generator, doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """Index every work unit one after another with a single ``Sync``.

    The transaction window (checkpoint to current txid) and the index name
    are read once up front and reused for every work unit. After the last
    unit, ``logical_slot_changes`` is run for ``doc``.

    Args:
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        verbose: Forwarded to ``Sync``.
        validate: Forwarded to ``Sync``.
    """
    sys.stdout.write("Synchronous\n")
    runner = Sync(doc, verbose=verbose, validate=validate)
    lower_txid = runner.checkpoint
    upper_txid = runner.txid_current
    target_index = runner.index
    for chunk in tasks:
        runner.es.bulk(
            target_index,
            runner.sync(ctid=chunk, txmin=lower_txid, txmax=upper_txid),
        )
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Index the work units with a pool of threads fed from a queue.

    ``nprocs`` daemon threads (at least one) share a single ``Sync`` and
    pull ctid work units from a ``Queue``. Once every unit has been
    processed the workers are told to exit and are joined, then
    ``logical_slot_changes`` is run for ``doc``.

    A failure while indexing one unit is written to stdout (same format as
    ``multiprocess``) and the worker moves on to the next unit.

    Args:
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        nprocs: Number of worker threads; ``None`` or ``0`` means one.
        verbose: Forwarded to ``Sync``.
        validate: Forwarded to ``Sync``.
    """
    sys.stdout.write("Multithreaded\n")

    def worker(sync: Sync, queue: Queue) -> None:
        txmin: int = sync.checkpoint
        txmax: int = sync.txid_current
        while True:
            task: Optional[dict] = queue.get()
            try:
                if task is None:
                    # Shutdown sentinel: stop this worker.
                    return
                sync.es.bulk(
                    sync.index,
                    sync.sync(ctid=task, txmin=txmin, txmax=txmax),
                )
            except Exception as e:
                # Keep the worker alive so the remaining units still get
                # consumed.
                sys.stdout.write(f"Exception: {e}\n")
            finally:
                # Always acknowledge the item; a missed task_done() would
                # leave queue.join() blocked forever.
                queue.task_done()

    nprocs: int = nprocs or 1
    queue: Queue = Queue()
    sync: Sync = Sync(doc, verbose=verbose, validate=validate)
    threads: list = []
    for _ in range(nprocs):
        thread: Thread = Thread(
            target=worker,
            args=(
                sync,
                queue,
            ),
        )
        thread.daemon = True
        thread.start()
        threads.append(thread)
    for task in tasks:
        queue.put(task)

    queue.join()  # block until all tasks are done

    # One sentinel per worker so the threads exit instead of idling on
    # queue.get() for the rest of the program.
    for _ in threads:
        queue.put(None)
    for thread in threads:
        thread.join()

    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Index the work units with a ``ProcessPoolExecutor``.

    ``Task.process`` is mapped over ``tasks``. The results are drained so
    that an exception raised by a work unit surfaces here, where it is
    written to stdout rather than propagated. ``logical_slot_changes`` is
    run for ``doc`` afterwards in either case.

    Args:
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        nprocs: ``max_workers`` for the process pool.
        verbose: Forwarded to ``Task``.
        validate: Forwarded to ``Task``.
    """
    sys.stdout.write("Multiprocess\n")
    runner = Task(doc, verbose=verbose, validate=validate)
    with ProcessPoolExecutor(max_workers=nprocs) as pool:
        try:
            # Results are not used; iterating only forces completion and
            # re-raises the first worker error.
            for _ in pool.map(runner.process, tasks):
                pass
        except Exception as exc:
            sys.stdout.write(f"Exception: {exc}\n")

    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded_async(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Index the work units via asyncio on top of a thread pool.

    ``run_tasks`` is driven to completion with a ``ThreadPoolExecutor`` as
    its executor, then ``logical_slot_changes`` is run for ``doc``.

    Args:
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        nprocs: ``max_workers`` for the thread pool.
        verbose: Forwarded to ``run_tasks``.
        validate: Forwarded to ``run_tasks``.
    """
    sys.stdout.write("Multi-threaded async\n")
    # The context manager shuts the pool down even if run_tasks raises.
    with ThreadPoolExecutor(max_workers=nprocs) as executor:
        # asyncio.run() creates and closes a fresh loop per call. main()
        # calls this once per document, and a loop that has already been
        # closed cannot be run again.
        asyncio.run(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )

    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess_async(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Index the work units via asyncio on top of a process pool.

    ``run_tasks`` is driven to completion with a ``ProcessPoolExecutor`` as
    its executor, then ``logical_slot_changes`` is run for ``doc``.

    Args:
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        nprocs: ``max_workers`` for the process pool.
        verbose: Forwarded to ``run_tasks``.
        validate: Forwarded to ``run_tasks``.
    """
    sys.stdout.write("Multi-process async\n")
    # The context manager shuts the pool down even if run_tasks raises, so
    # worker processes are not left behind.
    with ProcessPoolExecutor(max_workers=nprocs) as executor:
        # asyncio.run() creates and closes a fresh loop per call. main()
        # calls this once per document, and a loop that has already been
        # closed cannot be run again.
        asyncio.run(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )

    logical_slot_changes(doc, verbose=verbose, validate=validate)


async def run_tasks(
    executor: Union[ThreadPoolExecutor, ProcessPoolExecutor],
    tasks: Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Schedule ``run_task`` for every work unit and wait for all of them.

    With a ``ThreadPoolExecutor`` one ``Sync`` is created here and passed to
    every call. With any other executor ``doc`` is passed instead and each
    call builds its own ``Sync``. The collected return values are printed.

    Args:
        executor: Pool that runs the blocking ``run_task`` calls.
        tasks: Iterable of ctid work units (page -> row numbers).
        doc: One schema document from the config file.
        verbose: Forwarded to ``Sync`` / ``run_task``.
        validate: Forwarded to ``Sync`` / ``run_task``.

    Raises:
        Exception: whatever a failed ``run_task`` call raised, re-raised
            when its result is collected.
    """
    event_loop = asyncio.get_running_loop()
    if isinstance(executor, ThreadPoolExecutor):
        # threads can share a common Sync object
        sync: Sync = Sync(doc, verbose=verbose, validate=validate)
        extra_args: tuple = (sync, None, verbose, validate)
    else:
        # no shared Sync here: run_task builds one from doc on each call
        extra_args = (None, doc, verbose, validate)

    futures: list = [
        event_loop.run_in_executor(executor, run_task, task, *extra_args)
        for task in tasks
    ]

    results: list = []
    if futures:
        # asyncio.wait() raises ValueError when given an empty collection.
        completed, _ = await asyncio.wait(futures)
        results = [future.result() for future in completed]
    print("results: {!r}".format(results))
    print("exiting")


def run_task(
    task: dict,
    sync: Optional[Sync] = None,
    doc: Optional[dict] = None,
    verbose: bool = False,
    validate: bool = False,
) -> int:
    """Bulk-index the rows selected by one ctid work unit.

    Args:
        task: Mapping of page number to row numbers.
        sync: Existing ``Sync`` to use; when ``None`` a new one is built
            from ``doc``.
        doc: Schema document used only when ``sync`` is ``None``.
        verbose: Forwarded to ``Sync`` when one is built here.
        validate: Forwarded to ``Sync`` when one is built here.

    Returns:
        Always ``1``.
    """
    if sync is None:
        sync = Sync(doc, verbose=verbose, validate=validate)
    lower_txid = sync.checkpoint
    upper_txid = sync.txid_current
    sync.es.bulk(
        sync.index,
        sync.sync(ctid=task, txmin=lower_txid, txmax=upper_txid),
    )
    print("run_task complete")
    return 1


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    default=False,
    help="Turn on verbosity",
)
@click.option(
    "--nprocs",
    "-n",
    help="Number of threads/process",
    type=int,
    default=multiprocessing.cpu_count() * 2,
)
@click.option(
    "--mode",
    "-m",
    help="Sync mode",
    type=click.Choice(
        [
            "synchronous",
            "multithreaded",
            "multiprocess",
            "multithreaded_async",
            "multiprocess_async",
        ],
        case_sensitive=False,
    ),
    default="multiprocess_async",
)
def main(config, nprocs, mode, verbose):
    """Run a parallel sync for every document in the schema config."""
    # NOTE: click shows the docstring above as the --help text, so the
    # development notes live in this comment instead.
    # TODO:
    # - Track progress across cpus/threads
    # - Save ctid
    # - Handle KeyboardInterrupt

    show_settings()
    config: str = get_config(config)

    # json.load() reads the whole file, so the handle can be closed before
    # the sync starts.
    with open(config) as fp:
        docs = json.load(fp)

    for doc in docs:
        tasks: Generator = fetch_tasks(doc)
        if mode == "synchronous":
            synchronous(tasks, doc, verbose=verbose)
        elif mode == "multithreaded":
            multithreaded(tasks, doc, nprocs=nprocs, verbose=verbose)
        elif mode == "multiprocess":
            multiprocess(tasks, doc, nprocs=nprocs, verbose=verbose)
        elif mode == "multithreaded_async":
            multithreaded_async(tasks, doc, nprocs=nprocs, verbose=verbose)
        elif mode == "multiprocess_async":
            multiprocess_async(tasks, doc, nprocs=nprocs, verbose=verbose)


# Invoke the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
