#!/usr/bin/env python

"""
Parallel sync is an experimental feature that leverages the available CPUs
to increase throughput.
This can be useful in environments that are subject to network latency.

In this scenario, your PG database, Elasticsearch, and PGSync app servers are
on different networks with a delay between request/response time.

The main bottleneck, in this case, is usually the roundtrip required for
each database query.
Even with server-side cursors in use, we are still only able to fetch
a limited number of records at a time from the cursor.
The delay in the next cursor fetch can slow down the overall sync
considerably.

The solution here is to perform an initial "fast"/"parallel" sync using this
tool to populate Elasticsearch with the majority of the records.
We can then revert to the native sync process.

This approach uses the Tuple identifier record of the table columns.
Each table contains a hidden system column - "ctid" of type "tid" that
identifies the page record and row number in each block.

We use this to paginate the sync process.
Pagination here technically implies that we are splitting each paged record
between CPUs.
This allows us to perform Elasticsearch bulk inserts in parallel.
The "ctid" is a tuple of (page, row-number) e.g. (1, 5) that identifies the
row in a disk page.

This method allows us to fetch all paged row records upfront and split them
into work units amongst the workers. Each chunk of work is defined by the
BLOCK_SIZE and this corresponds to the number of root node records to process
for each worker.
The workers then query for each chunk of work filtering by the page number
and row numbers assigned.
"""

import json
import logging
import multiprocessing
import sys
from dataclasses import dataclass
from typing import Optional

import click
import sqlalchemy as sa

from pgsync.node import Node
from pgsync.settings import BLOCK_SIZE
from pgsync.sync import Sync
from pgsync.utils import get_config, show_settings, Timer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class Task:
    """Unit of work placed on the queue for a worker process.

    Attributes:
        ctid: Mapping of page number to the list of row numbers on that
            page, i.e. one chunk as built by ``fetch_tasks``.
        txmin: Lower transaction id bound forwarded to ``Sync.sync``.
        txmax: Upper transaction id bound forwarded to ``Sync.sync``.
    """

    # A chunk is a {page: [row, ...]} dict, not a single integer.
    ctid: dict
    txmin: Optional[int] = None
    txmax: Optional[int] = None


class Worker(multiprocessing.Process):
    """Process that consumes tasks from a shared queue and bulk-indexes them."""

    def __init__(self, queue: multiprocessing.JoinableQueue, doc: dict):
        super().__init__()
        self.queue: multiprocessing.JoinableQueue = queue
        self.doc: dict = doc

    def run(self) -> None:
        """Process queued tasks until a ``None`` sentinel is received.

        A ``Sync`` is built from ``self.doc`` once, then every task's
        ctid/txmin/txmax are passed to ``sync.sync`` and the result is handed
        to ``sync.es.bulk``. Each item taken off the queue, including the
        sentinel, is acknowledged with ``task_done()``.
        """
        # NOTE(review): if the bulk call raises, task_done() is not reached
        # for that task, so a queue.join() in the parent would keep waiting.
        sync: Sync = Sync(self.doc)
        while True:
            job = self.queue.get()
            if job is not None:
                sync.es.bulk(
                    sync.index,
                    sync.sync(ctid=job.ctid, txmin=job.txmin, txmax=job.txmax),
                )
                self.queue.task_done()
                continue
            # Sentinel received: announce, acknowledge it and stop.
            sys.stdout.write(f"{self.name}: Exiting...\n")
            sys.stdout.flush()
            self.queue.task_done()
            break


def fetch_tasks(doc: dict, block_size: Optional[int] = None) -> list:
    """Split the root table's rows into ctid-addressed chunks of work.

    Every row of the root node's model is read via its ``ctid`` and grouped
    into chunks of at most ``block_size`` rows.

    Args:
        doc: Schema document used to construct the ``Sync`` instance.
        block_size: Number of root rows per chunk; falls back to
            ``BLOCK_SIZE`` when not given (or falsy).

    Returns:
        A list of chunks. Each chunk is a dict mapping a page number to the
        list of row numbers on that page. The last chunk may hold fewer than
        ``block_size`` rows.
    """
    block_size = block_size or BLOCK_SIZE
    tasks: list = []
    pages: dict = {}
    sync: Sync = Sync(doc)
    root: Node = sync.tree.build(sync.nodes)
    with Timer("Query tasks..."):
        statement: sa.sql.selectable.Select = sa.select(
            [
                sa.literal_column("1").label("x"),
                sa.literal_column("1").label("y"),
                sa.column("ctid"),
            ]
        ).select_from(root.model)

        for i, (_, _, ctid) in enumerate(sync.fetchmany(statement), start=1):
            # presumably ctid[0] is the textual tid, e.g. "(1,5)" - the
            # parsing below relies on that "(page,row)" shape.
            value: list = ctid[0].split(",")
            page: int = int(value[0].replace("(", ""))
            row: int = int(value[1].replace(")", ""))
            pages.setdefault(page, []).append(row)
            if i % block_size == 0:
                tasks.append(pages)
                pages = {}

        # Flush the trailing partial chunk; without this, rows past the last
        # full multiple of block_size would never be handed to a worker.
        if pages:
            tasks.append(pages)
    return tasks


def pull(ncpus: int, config: str, verbose: Optional[bool] = False) -> None:
    """Pull data from db using multiprocessing.

    For every document in the JSON config file, the root table is split into
    chunks with ``fetch_tasks`` and the chunks are distributed over ``ncpus``
    worker processes through a shared joinable queue. Once the queue has
    drained, the checkpoint is set and ``logical_slot_changes`` is run for
    the txmin..txmax range.

    Args:
        ncpus: Number of worker processes to start per document.
        config: Path to the JSON schema config file.
        verbose: Passed through to ``Sync``.
    """
    queue: multiprocessing.JoinableQueue = multiprocessing.JoinableQueue()
    with Timer():
        # Read the config inside a context manager so the file handle is
        # closed deterministically instead of being leaked.
        with open(config) as fp:
            docs = json.load(fp)

        for doc in docs:
            logger.debug("ncpus   : %s", ncpus)
            tasks: list = fetch_tasks(doc)
            logger.debug("tasks   : %s", len(tasks))
            workers: list = [Worker(queue, doc) for _ in range(ncpus)]
            sync: Sync = Sync(doc, verbose=verbose)
            txmin: int = sync.checkpoint
            txmax: int = sync.txid_current
            logger.debug("pull txmin: %s - txmax: %s", txmin, txmax)

            for worker in workers:
                worker.start()
            for ctid in tasks:
                queue.put(Task(ctid, txmin, txmax))
            # One None sentinel per worker tells each of them to exit.
            for _ in range(ncpus):
                queue.put(None)
            # Block until every queued item has been acknowledged.
            queue.join()

            sync.checkpoint = txmax or sync.txid_current
            # now sync up to txmax to capture everything we may have missed
            sync.logical_slot_changes(txmin=txmin, txmax=txmax)


@click.command()
@click.option(
    "--config", "-c", type=click.Path(exists=True), help="Schema config"
)
@click.option(
    "--verbose", "-v", is_flag=True, default=False, help="Turn on verbosity"
)
@click.option(
    "--ncpus",
    "-n",
    type=int,
    default=multiprocessing.cpu_count() * 2,
    help="Number of cpus",
)
def main(config, ncpus, verbose):
    # CLI entry point: print the settings, resolve the config through
    # get_config, then run the parallel pull with the requested worker count.
    # Kept as comments rather than a docstring because click would surface a
    # docstring as the command's help text.
    show_settings()
    pull(ncpus, get_config(config), verbose=verbose)


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
