#!/usr/bin/env python

"""
Parallel sync is an experimental feature that leverages the available
CPUs/threads to increase throughput.
This can be useful for environments that have a high network latency.

In this scenario, your PG database, Elasticsearch/OpenSearch, and PGSync
servers are on different networks with a delay between request/response time.
The main bottleneck, in this case, is usually the roundtrip of the database
query.

Even with server-side cursors, we are still only able to fetch
a limited number of records at a time from the cursor.
The delay in the next cursor fetch can slow down the overall sync
considerably.

The solution here is to perform an initial fast/parallel sync 
to populate Elasticsearch/OpenSearch in a single iteration.
When this is complete, we can then continue to run the normal `pgsync`
as a daemon.

This approach uses the Tuple identifier record of the table columns.
Each table contains a system column - "ctid" of type "tid" that
identifies the page record and row number in each block.

We can use this to paginate the sync process.
Pagination here technically implies that we are splitting each paged record
between CPUs/threads.

This allows us to perform Elasticsearch/OpenSearch bulk inserts in parallel.
The "ctid" is a tuple of (page, row-number) e.g (1, 5) that identifies the
row in a disk page.

This method allows us to fetch all paged row records upfront and split them
into work units amongst the workers(threads/cpus).
Each chunk of work is defined by the BLOCK_SIZE and corresponds to the number
of root node records each worker needs to process.

The workers query for each chunk of work, filtering by the page number
and row numbers.
"""

import asyncio
import multiprocessing
import os
import re
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import Generator, Optional, Union

import click
import sqlalchemy as sa

from pgsync.settings import BLOCK_SIZE, CHECKPOINT_PATH
from pgsync.sync import Sync
from pgsync.utils import (
    compiled_query,
    config_loader,
    get_config,
    show_settings,
    timeit,
)


def save_ctid(page: int, row: int, name: str) -> None:
    """Write a ctid checkpoint for ``name`` under CHECKPOINT_PATH.

    The checkpoint lives in a hidden ``.<name>.ctid`` file which is
    truncated and rewritten with a single ``page,row`` line on each call.

    Args:
        page: page component of the ctid to record.
        row: row component of the ctid to record.
        name: identifier used to build the checkpoint file name.
    """
    filename = f".{name}.ctid"
    target = os.path.join(CHECKPOINT_PATH, filename)
    with open(target, "w+") as handle:
        handle.write(f"{page},{row}\n")


def read_ctid(name: str) -> tuple:
    """Read the ctid checkpoint stored for ``name`` under CHECKPOINT_PATH.

    Args:
        name: identifier used to build the ``.<name>.ctid`` file name.

    Returns:
        A ``(page, row)`` tuple of ints parsed from the first line of the
        checkpoint file, or ``(None, None)`` when the file does not exist
        or contains no data.

    Raises:
        ValueError / IndexError: if the file has content that is not in the
            ``page,row`` form (a corrupt checkpoint is not silently ignored).
    """
    checkpoint_file: str = os.path.join(CHECKPOINT_PATH, f".{name}.ctid")
    # EAFP: open directly instead of exists()+open(), which is racy when
    # another worker creates or rewrites the file in between.
    try:
        with open(checkpoint_file, "r") as fp:
            lines: list = fp.read().split()
    except FileNotFoundError:
        return None, None
    if not lines:
        # An empty file (e.g. just truncated by a writer) means there is
        # no checkpoint yet.
        return None, None
    pairs: list = lines[0].split(",")
    page: int = int(pairs[0])
    row: int = int(pairs[1])
    return page, row


def logical_slot_changes(
    doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """Catch up on logical slot changes and advance the checkpoint.

    Builds a fresh Sync for ``doc``, asks it to process logical slot changes
    between its current checkpoint and the current transaction id, then
    stores that upper bound as the new checkpoint.

    Args:
        doc: schema document passed to Sync.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    syncer: Sync = Sync(doc, verbose=verbose, validate=validate)
    # Capture the window once so that everything we may have missed while
    # the parallel sync was running is picked up.
    lower: int = syncer.checkpoint
    upper: int = syncer.txid_current
    syncer.logical_slot_changes(txmin=lower, txmax=upper)
    # Fall back to a fresh transaction id if the captured one is falsy.
    syncer.checkpoint = upper or syncer.txid_current


@dataclass
class Task:
    """Unit of work handed to the process pool in ``multiprocess`` mode.

    Holds the arguments needed to construct a Sync inside the worker, so
    only plain data is stored on the instance.
    """

    # schema document passed to Sync
    doc: dict
    # forwarded to Sync
    verbose: bool = False
    # forwarded to Sync
    validate: bool = False

    def process(self, task: dict) -> None:
        """Sync one chunk of ctids and bulk-index the result.

        Args:
            task: chunk of work passed to ``Sync.sync`` as ``ctid``
                (produced by ``fetch_tasks`` as page -> row numbers).
        """
        worker: Sync = Sync(
            self.doc,
            verbose=self.verbose,
            validate=self.validate,
        )
        worker.tree.build(worker.nodes)
        lower: int = worker.checkpoint
        upper: int = worker.txid_current
        worker.search_client.bulk(
            worker.index,
            worker.sync(ctid=task, txmin=lower, txmax=upper),
        )
        # Report completion together with the pid of the executing process.
        sys.stdout.write(f"Process pid: {os.getpid()} complete.\n")
        sys.stdout.flush()


@timeit
def fetch_tasks(
    doc: dict,
    block_size: Optional[int] = None,
) -> Generator:
    """Yield chunks of root-table ctids to be synced by the workers.

    Selects the ``ctid`` of every row of the root model and groups them into
    dicts mapping page number -> list of row numbers. A chunk is yielded
    every ``block_size`` rows, followed by whatever remains at the end.

    If a ctid checkpoint exists for this database/index, only rows located
    strictly after that ``(page, row)`` position are selected. The position
    is compared lexicographically: a later page, or the same page with a
    greater row number. (Filtering page and row independently would skip
    the rest of the checkpoint page and low-numbered rows on later pages.)

    Args:
        doc: schema document passed to Sync.
        block_size: number of rows per chunk; defaults to BLOCK_SIZE.

    Yields:
        dict of ``{page: [row, ...]}``. The final chunk is always yielded
        and may be empty, so at least one chunk is always produced.
    """
    block_size = block_size or BLOCK_SIZE
    sync: Sync = Sync(doc)
    name: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    page, row = read_ctid(name=name)

    def ctid_part(position: int):
        """Return ctid's page (1) or row (2) component cast to Integer."""
        return sa.cast(
            sa.func.SPLIT_PART(
                sa.func.REPLACE(
                    sa.func.REPLACE(
                        sa.cast(sa.column("ctid"), sa.Text),
                        "(",
                        "",
                    ),
                    ")",
                    "",
                ),
                ",",
                position,
            ),
            sa.Integer,
        )

    statement: sa.sql.Select = sa.select(
        [
            sa.literal_column("1").label("x"),
            sa.literal_column("1").label("y"),
            sa.column("ctid"),
        ]
    ).select_from(sync.tree.root.model)

    # Resume after the checkpoint. Use explicit None checks so that a
    # checkpoint on page 0 is still honoured.
    if page is not None:
        after_checkpoint = ctid_part(1) > page
        if row is not None:
            after_checkpoint = sa.or_(
                after_checkpoint,
                sa.and_(ctid_part(1) == page, ctid_part(2) > row),
            )
        statement = statement.where(after_checkpoint)

    pages: dict = {}
    # NOTE(review): fetchmany appears to yield 3-tuples whose last item is a
    # sequence holding the ctid text, e.g. "(1,5)" - confirm against Sync.
    for count, (_, _, ctid) in enumerate(sync.fetchmany(statement), start=1):
        value: list = ctid[0].split(",")
        page_number: int = int(value[0].replace("(", ""))
        row_number: int = int(value[1].replace(")", ""))
        pages.setdefault(page_number, []).append(row_number)
        if count % block_size == 0:
            yield pages
            pages = {}
    # Always emit the trailing chunk (possibly empty) so that consumers
    # receive at least one task.
    yield pages


@timeit
def synchronous(
    tasks: Generator, doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """Process every task sequentially with a single Sync instance.

    Args:
        tasks: chunks of work, each passed to ``Sync.sync`` as ``ctid``.
        doc: schema document passed to Sync.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    sys.stdout.write("Synchronous\n")
    runner: Sync = Sync(doc, verbose=verbose, validate=validate)
    runner.tree.build(runner.nodes)
    # The transaction window and index are read once and reused per task.
    window: dict = {
        "txmin": runner.checkpoint,
        "txmax": runner.txid_current,
    }
    target_index: str = runner.index
    for chunk in tasks:
        runner.search_client.bulk(
            target_index,
            runner.sync(ctid=chunk, **window),
        )
    # Finally pick up anything that changed while the sync was running.
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Process tasks with a pool of daemon threads fed from a queue.

    A single Sync instance is shared by all worker threads. Each worker
    pulls tasks from the queue, syncs the chunk and bulk-indexes it.

    Args:
        tasks: chunks of work, each passed to ``Sync.sync`` as ``ctid``.
        doc: schema document passed to Sync.
        nprocs: number of worker threads; defaults to 1.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    sys.stdout.write("Multithreaded\n")

    def worker(sync: Sync, queue: Queue) -> None:
        """Consume tasks forever; the thread is a daemon and dies at exit."""
        txmin: int = sync.checkpoint
        txmax: int = sync.txid_current
        while True:
            task: dict = queue.get()
            try:
                sync.search_client.bulk(
                    sync.index,
                    sync.sync(ctid=task, txmin=txmin, txmax=txmax),
                )
            except Exception as e:
                # Report and keep the worker alive; a dead worker would stop
                # draining the queue.
                sys.stdout.write(f"Exception: {e}\n")
            finally:
                # Always acknowledge the task, otherwise queue.join() below
                # blocks forever after a failure.
                queue.task_done()

    nprocs: int = nprocs or 1
    queue: Queue = Queue()
    sync: Sync = Sync(doc, verbose=verbose, validate=validate)
    sync.tree.build(sync.nodes)

    for _ in range(nprocs):
        thread: Thread = Thread(
            target=worker,
            args=(
                sync,
                queue,
            ),
        )
        thread.daemon = True
        thread.start()
    for task in tasks:
        queue.put(task)

    queue.join()  # block until all tasks are done
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Process tasks in a process pool via ``Task.process``.

    Any exception surfaced while collecting results is written to stdout
    and the remaining results are not collected; the logical slot catch-up
    still runs afterwards.

    Args:
        tasks: chunks of work handed to ``Task.process``.
        doc: schema document passed to Sync.
        nprocs: maximum number of worker processes.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    sys.stdout.write("Multiprocess\n")
    job: Task = Task(doc, verbose=verbose, validate=validate)
    with ProcessPoolExecutor(max_workers=nprocs) as pool:
        try:
            # Drain the result iterator so worker exceptions surface here.
            for _ in pool.map(job.process, tasks):
                pass
        except Exception as e:
            sys.stdout.write(f"Exception: {e}\n")
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded_async(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Process tasks on a thread pool driven by the asyncio event loop.

    Args:
        tasks: chunks of work, each handed to ``run_task``.
        doc: schema document passed to Sync.
        nprocs: maximum number of worker threads.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    sys.stdout.write("Multi-threaded async\n")
    # Use the executor as a context manager so its threads are released
    # once this document is done instead of lingering until exit.
    with ThreadPoolExecutor(max_workers=nprocs) as executor:
        event_loop = asyncio.get_event_loop()
        event_loop.run_until_complete(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess_async(
    tasks: Generator,
    doc: dict,
    nprocs: Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Process tasks on a process pool driven by the asyncio event loop.

    A KeyboardInterrupt raised while waiting on the event loop is swallowed
    so that the logical slot catch-up still runs.

    Args:
        tasks: chunks of work, each handed to ``run_task``.
        doc: schema document passed to Sync.
        nprocs: maximum number of worker processes.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.
    """
    sys.stdout.write("Multi-process async\n")
    # Use the executor as a context manager so its worker processes are
    # shut down once this document is done instead of lingering until exit.
    with ProcessPoolExecutor(max_workers=nprocs) as executor:
        event_loop = asyncio.get_event_loop()
        try:
            event_loop.run_until_complete(
                run_tasks(
                    executor, tasks, doc, verbose=verbose, validate=validate
                )
            )
        except KeyboardInterrupt:
            # Deliberate: stop waiting on Ctrl-C and fall through to the
            # catch-up step below.
            pass
    logical_slot_changes(doc, verbose=verbose, validate=validate)


async def run_tasks(
    executor: Union[ThreadPoolExecutor, ProcessPoolExecutor],
    tasks: Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """Schedule ``run_task`` for every task on ``executor`` and await them.

    With a ThreadPoolExecutor a single Sync is created here and shared by
    all tasks; otherwise ``None`` is passed and ``run_task`` creates its own
    Sync from ``doc``.

    Args:
        executor: pool on which the tasks are run.
        tasks: chunks of work, each handed to ``run_task``.
        doc: schema document passed to Sync.
        verbose: forwarded to Sync.
        validate: forwarded to Sync.

    Raises:
        Exception: whatever a failed task raised, re-raised when its result
            is collected.
    """
    sync: Optional[Sync] = None
    if isinstance(executor, ThreadPoolExecutor):
        # threads can share a common Sync object
        sync = Sync(doc, verbose=verbose, validate=validate)
    # We are inside a coroutine, so the running loop is the one to use.
    event_loop = asyncio.get_running_loop()
    futures: list = [
        event_loop.run_in_executor(
            executor, run_task, task, sync, doc, verbose, validate
        )
        for task in tasks
    ]
    results: list = []
    if futures:
        # asyncio.wait() raises ValueError on an empty collection, so only
        # wait when something was actually scheduled.
        completed, _ = await asyncio.wait(futures)
        results = [future.result() for future in completed]
    print("results: {!r}".format(results))
    print("exiting")


def run_task(
    task: dict,
    sync: Optional[Sync] = None,
    doc: Optional[dict] = None,
    verbose: bool = False,
    validate: bool = False,
) -> int:
    """Sync one chunk of ctids, bulk-index it and record a ctid checkpoint.

    Args:
        task: mapping of page -> row numbers, passed to ``Sync.sync`` as
            ``ctid``.
        sync: Sync to use; when ``None`` a new one is built from ``doc``.
        doc: schema document used only when ``sync`` is ``None``.
        verbose: forwarded to Sync when one is built here.
        validate: forwarded to Sync when one is built here.

    Returns:
        Always ``1``.
    """
    if sync is None:
        sync = Sync(doc, verbose=verbose, validate=validate)
    sync.tree.build(sync.nodes)
    lower: int = sync.checkpoint
    upper: int = sync.txid_current
    sync.search_client.bulk(
        sync.index,
        sync.sync(ctid=task, txmin=lower, txmax=upper),
    )
    if not task:
        # Nothing was processed, so there is no position to checkpoint.
        return 1

    # Checkpoint the highest page of this chunk and its highest row.
    last_page: int = max(task)
    last_row: int = max(task[last_page])
    checkpoint_name: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    save_ctid(page=last_page, row=last_row, name=checkpoint_name)
    return 1


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    default=False,
    help="Turn on verbosity",
)
@click.option(
    "--nprocs",
    "-n",
    help="Number of threads/process",
    type=int,
    default=multiprocessing.cpu_count() * 2,
)
@click.option(
    "--mode",
    "-m",
    help="Sync mode",
    type=click.Choice(
        [
            "synchronous",
            "multithreaded",
            "multiprocess",
            "multithreaded_async",
            "multiprocess_async",
        ],
        case_sensitive=False,
    ),
    default="multiprocess_async",
)
def main(config, nprocs, mode, verbose):
    """
    TODO:
    - Track progress across cpus/threads
    - Handle KeyboardInterrupt Exception
    """
    # NOTE(review): the docstring above is presumably what click shows as
    # the command help, so it is kept verbatim.

    # Strategies that accept an ``nprocs`` argument, keyed by --mode value.
    parallel_modes: dict = {
        "multithreaded": multithreaded,
        "multiprocess": multiprocess,
        "multithreaded_async": multithreaded_async,
        "multiprocess_async": multiprocess_async,
    }

    show_settings()
    config_path: str = get_config(config)

    # Run the selected strategy once per document in the schema config.
    for document in config_loader(config_path):
        chunks: Generator = fetch_tasks(document)
        if mode == "synchronous":
            # The sequential strategy takes no worker count.
            synchronous(chunks, document, verbose=verbose)
        elif mode in parallel_modes:
            parallel_modes[mode](
                chunks, document, nprocs=nprocs, verbose=verbose
            )


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
