#!/usr/bin/env python3

"""Reconcile and aggregate results."""

import os
from re import T
import shutil
import logging
import json

import click

from anvil.util.reconciler import DEFAULT_CONSORTIUMS, DEFAULT_OUTPUT_PATH, DEFAULT_NAMESPACE, aggregate
from anvil.terra.reconciler import Entities
# Configure the root logger once at import time: INFO and above, with each
# record showing timestamp, source filename, padded level name and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(filename)s %(levelname)-8s %(message)s')


@click.group()
@click.pass_context
def cli(ctx: click.Context) -> None:
    """Set up context, main entrypoint."""
    # NOTE: click shows the docstring above as the group's --help text, so
    # implementation notes are kept in comments rather than in the docstring.
    #
    # Guarantee that ctx.obj exists and is a dict, so sub-commands have a shared
    # place to stash state if they are ever chained together into a pipeline.
    ctx.ensure_object(dict)


@cli.command('clean')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, help=f'output path default={DEFAULT_OUTPUT_PATH}')
def cleaner(output_path):
    """Remove sqlite work databases, json files, etc."""
    # Best-effort cleanup: every failure is logged and the sweep continues.
    log = logging.getLogger(__name__)

    def _drop(target, remover):
        """Apply `remover` to `target`, logging the outcome instead of raising."""
        try:
            remover(target)
            log.info(f"Dropped {target}")
        except FileNotFoundError:
            log.warning(f"{target} FileNotFound")
        except Exception as err:
            log.error(f"{target} {err}")

    # Work products written under output_path, removed in this order.
    artifact_names = (
        'data_dashboard.json',
        'data_dashboard.tsv',
        'gen3-drs.sqlite',
        'spreadsheet.json',
        'terra.sqlite',
        'terra-graph.sqlite',
        'pyanvil-cache.sqlite',
        'qa-report.md',
        'terra_summary.json',
    )
    for artifact in artifact_names:
        _drop(f'{output_path}/{artifact}', os.unlink)

    # TODO refactor cache to configure path
    _drop('/tmp/pyanvil-cache.sqlite', os.unlink)

    # One directory per consortium, named after the first element of each entry.
    for consortium in DEFAULT_CONSORTIUMS:
        _drop(f'{output_path}/{consortium[0]}', shutil.rmtree)


@cli.command('extract')
@click.option('--user_project', default=os.environ.get('GOOGLE_PROJECT', None), help=f'Google billing project. default={os.environ.get("GOOGLE_PROJECT", None)}')
@click.option('--namespace', default=DEFAULT_NAMESPACE, help=f'Terra namespace default={DEFAULT_NAMESPACE}')
@click.option('--consortiums', type=(str, str), default=DEFAULT_CONSORTIUMS, multiple=True, help=f'<Name Regexp> e.g "CCDG AnVIL_CCDG.*" default {DEFAULT_CONSORTIUMS}')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, help=f'output path default={DEFAULT_OUTPUT_PATH}')
@click.option('--pfb_path', required=True, help='pfb path (gen3 avro extract)')
def extractor(user_project, namespace, consortiums, output_path, pfb_path):
    """Harvest all workspaces, return list of workspace_name. Create detailed sqlite graph and summary dashboard."""
    # NOTE: click shows the docstring above as the command's --help text, so
    # the details are documented here instead.
    #
    # Parameters (all supplied by click options):
    #   user_project -- Google billing project; must be non-empty.
    #   namespace    -- Terra namespace forwarded to aggregate().
    #   consortiums  -- (name, regexp) pairs forwarded to aggregate().
    #   output_path  -- directory that receives data_dashboard.json; the
    #                   terra.sqlite and gen3-drs.sqlite paths under it are
    #                   handed to aggregate().
    #   pfb_path     -- gen3 avro extract, forwarded as avro_path.
    #
    # Raises AssertionError when user_project is missing, when no views were
    # produced, or when an expected output file does not exist afterwards.
    logger = logging.getLogger(__name__)
    logger.info("Starting aggregation for all specified AnVIL workspaces this will take several minutes.")
    logger.info(f"Reading from consortiums {consortiums}; pfb_path {pfb_path}")
    logger.info(f"Writing to output_path {output_path}")
    assert user_project, "Please provide --user_project (or set GOOGLE_PROJECT)"
    dashboard_output_path = f"{output_path}/data_dashboard.json"
    terra_output_path = f"{output_path}/terra.sqlite"
    drs_output_path = f"{output_path}/gen3-drs.sqlite"

    with open(dashboard_output_path, 'w') as outs:
        # Forward the --namespace option. Previously DEFAULT_NAMESPACE was
        # always passed here, so the command-line value was silently ignored.
        views = list(aggregate(namespace=namespace,
                               user_project=user_project,
                               consortium=consortiums, avro_path=pfb_path,
                               terra_output_path=terra_output_path,
                               drs_output_path=drs_output_path))
        # Views carrying a 'problems' key are reported as projects; all other
        # views are reported as consortiums.
        json.dump({
            'projects': [v for v in views if 'problems' in v],
            'consortiums': [v for v in views if 'problems' not in v]
        }, outs)

    # The f-prefix used to sit inside the quotes, so the message was never interpolated.
    assert len(views) > 0, f"{consortiums} matched no workspaces"
    assert os.path.isfile(dashboard_output_path), f"{dashboard_output_path} should exist."
    logger.info(f"Wrote summary to {dashboard_output_path}")

    assert os.path.isfile(terra_output_path), f"{terra_output_path} should exist."
    # Index the freshly written work database.
    entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
    entities.index()
    logger.info(f"Wrote work database to {terra_output_path}")


@cli.command('report')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, show_default=True, help='output path.')
@click.option('--user_project', default=os.environ.get('GOOGLE_PROJECT', None), show_default=True, help='Google billing project.')
def reporter(output_path, user_project):
    """Reconcile and report on harvested workspaces."""
    # NOTE: click shows the docstring above as the command's --help text, so
    # the details are documented here instead.
    #
    # Reads  {output_path}/data_dashboard.json, {output_path}/terra.sqlite and
    #        the `vertices` / `edges` tables of {output_path}/gen3-drs.sqlite.
    # Writes {output_path}/qa-report.md, {output_path}/data_dashboard.tsv,
    #        {output_path}/terra_summary.json, and (re)creates the tables
    #        terra_details, flattened, summary, reconcile_counts,
    #        missing_sequencing and subjects_missing_sequencing in
    #        gen3-drs.sqlite.
    # Raises AssertionError when the dashboard json does not exist.

    # Function-scope imports: only this command needs them.
    import sqlite3
    from contextlib import closing

    import pandas as pd
    from tabulate import tabulate

    from anvil.util.reconciler import flatten

    logger = logging.getLogger(__name__)

    terra_output_path = f"{output_path}/terra.sqlite"
    dashboard_output_path = f"{output_path}/data_dashboard.json"
    drs_output_path = f"{output_path}/gen3-drs.sqlite"
    report_path = f"{output_path}/qa-report.md"
    tsv_path = f"{output_path}/data_dashboard.tsv"
    terra_summary = f"{output_path}/terra_summary.json"

    logger.info("Starting reporting for all extracted AnVIL workspaces this will take several minutes.")
    entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
    workspace_names = [workspace.name for workspace in entities.get_by_name('workspace')]
    logger.info(f"Reporting on {len(workspace_names)} workspaces")

    def write_table(out, frame):
        """Render `frame` as a github-flavoured markdown table into `out`."""
        print(tabulate(frame, headers='keys', tablefmt='github'), file=out)

    def summarize_workspaces():
        """Write one json line per (workspace, subject, sample, blob) to terra_summary."""
        # A separate Entities instance is built here, as the original code did.
        workspace_entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
        # created sql indices
        workspace_entities.index()
        with open(terra_summary, "w") as emitter:
            for workspace in workspace_entities.get_by_name('workspace'):
                for subject in workspace.subjects:
                    for sample in subject.samples:
                        for _blob_key, blob in sample.blobs.items():
                            json.dump(
                                {
                                    "workspace_id": workspace.id,
                                    "subject_id": subject.id,
                                    "sample_id": sample.id,
                                    "blob": blob['name'],
                                },
                                emitter,
                                separators=(',', ':')
                            )
                            emitter.write('\n')
        logger.info(f"Wrote summary to {terra_summary}")

    #
    # reconcile with gen3
    #
    # NOTE(review): in `flattened`, the `left join` on sequencing edges is
    # followed by an inner `join` on vertices `sq`; rows whose sample has no
    # sequencing edge are therefore dropped. Confirm whether a left join on
    # `sq` was intended. The script is left unchanged here.
    sql = """

    -- missing sequencing
    drop table if exists flattened ;
    create table flattened
    as
    select
        json_extract(su.json, '$.object.project_id') as "project_id",
        json_extract(su.json, '$.object.anvil_project_id') as "anvil_project_id",
        su.name as "subject_type",
        su.key as "subject_id",
        json_extract(su.json, '$.object.participant_id') as "participant_id",
        json_extract(su.json, '$.object.submitter_id') as "subject_submitter_id",
        sa.name as "sample_type",
        sa.key  as "sample_id",
        json_extract(sa.json, '$.object.sample_id') as "sample_sample_id",
        json_extract(sa.json, '$.object.submitter_id') as "sample_submitter_id",
        json_extract(sa.json, '$.object.specimen_id') as "sample_specimen_id",
        'sequencing' as "sequencing_type",
        sequencing_edge.src  as "sequencing_id",
        json_extract(sq.json, '$.object.submitter_id') as "sequencing_submitter_id",
        json_extract(sq.json, '$.object.ga4gh_drs_uri') as "ga4gh_drs_uri"
        from vertices as su
            join edges as sample_edge on sample_edge.dst = su.key and sample_edge.src_name = 'sample'
                join vertices as sa on sample_edge.src = sa.key
                    left join edges as sequencing_edge on sequencing_edge.dst = sa.key and sequencing_edge.src_name = 'sequencing'
                        join vertices as sq on sequencing_edge.src = sq.key

        where
        su.name = 'subject'            ;


    drop table if exists summary ;
    create table summary
    as
    select f.project_id, f.anvil_project_id,
        count(distinct f.subject_id) as "subject_count",
        count(distinct f.sample_id) as "sample_count",
        count(distinct m.sequencing_id) as "sequencing_count",
        count(distinct m.ga4gh_drs_uri) as "ga4gh_drs_uri_count"
        from flattened as f
            left join flattened as m on f.project_id = m.project_id and f.anvil_project_id = m.anvil_project_id
        group by f.project_id, f.anvil_project_id;


    drop table if exists reconcile_counts;
    create table reconcile_counts as
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        count(distinct f.sample_submitter_id) as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        count(distinct f.ga4gh_drs_uri) as "gen3_drs_uri_count"
        from terra_details as w
            left join flattened as f on (w.sample_id || '_sample' = f.sample_submitter_id)
    group by w.workspace_id
    having gen3_sample_id_count > 0
    UNION
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        count(distinct f.sample_submitter_id) as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        count(distinct f.ga4gh_drs_uri) as "gen3_drs_uri_count"
        from terra_details as w
            left join flattened as f on (w.sample_id   = f.sample_submitter_id)
    group by w.workspace_id
    having gen3_sample_id_count > 0
    UNION
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        count(distinct f.sample_submitter_id) as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        count(distinct f.ga4gh_drs_uri) as "gen3_drs_uri_count"
        from terra_details as w
            left join flattened as f on (w.sample_id   = f.sample_specimen_id)
    group by w.workspace_id
    having gen3_sample_id_count > 0
    ;

    insert into reconcile_counts
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        0 as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        0 as "gen3_drs_uri_count"
    from terra_details  as w
    where workspace_id not in ( select distinct workspace_id from reconcile_counts )
    group by w.workspace_id    ;
    ;

    drop table if exists missing_sequencing;

    create table missing_sequencing
    as
    select s.key, s.submitter_id  from vertices  as s
    where s.name = 'sample'
    and
    not EXISTS(
        select q.src from edges as q where q.dst = s.key
    ) ;

    drop table if exists subjects_missing_sequencing;
    create table subjects_missing_sequencing
    as
    select s.key, s.submitter_id  from vertices  as s
    where s.name = 'subject'
    and s.key in
    (
        select q.dst from edges as q where q.src in (select ms.key from missing_sequencing as ms)
    ) ;


    """

    # Report sections backed by a query, in output order: (heading, query).
    reconcile_sections = (
        ("# PFB contains gen3 projects without anvil(terra) project",
         "SELECT * from summary where anvil_project_id is null;"),
        ("# Not all terra projects found in Gen3",
         "SELECT * from reconcile_counts where gen3_sample_id_count = 0;"),
        ("# Terra / Gen3 samples count mismatch",
         "SELECT * from reconcile_counts where gen3_sample_id_count > 0 and gen3_sample_id_count <> terra_sample_id_count;"),
        ("# Terra / Gen3 blob/drs count alignment",
         "SELECT * from reconcile_counts where terra_sample_id_count = gen3_sample_id_count and terra_blob_count = gen3_drs_uri_count;"),
        ("# Terra / Gen3 blob/drs count mismatch",
         "SELECT * from reconcile_counts where terra_sample_id_count = gen3_sample_id_count and terra_blob_count <> gen3_drs_uri_count;"),
    )

    logger.info(f"Writing report to {report_path}")
    # `with` guarantees the report is flushed and closed even if a later step raises.
    with open(report_path, 'w') as report_file:

        # validate output summary and load it
        assert os.path.isfile(dashboard_output_path), f"dashboard should exist {dashboard_output_path}"
        with open(dashboard_output_path, 'r') as inputs:
            dashboard_data = json.load(inputs)

        # Flatten dashboard into tsv
        (flattened, column_names) = flatten(dashboard_data['projects'])
        df = pd.DataFrame(flattened)
        df.columns = column_names
        # Print the data  (all rows, all columns)
        pd.set_option('display.max_rows', None)
        pd.set_option('display.max_columns', None)
        # export create a tsv from dataframe
        df.to_csv(tsv_path, sep="\t")
        logger.info(f"Wrote {tsv_path}")

        print("# Dashboard", file=report_file)
        write_table(report_file, df)

        # ## summarize terra exceptions
        # > Extract the list of data transformation problems encountered, see more on
        # > dashboard exceptions: https://github.com/anvilproject/client-apis/wiki/dashboard-exceptions
        _projects = [project for project in dashboard_data['projects'] if 'problems' in project]
        problems = {problem for project in _projects for problem in project['problems']}
        # One row per distinct problem: the problem and the comma-joined ids of affected projects.
        exception_rows = [
            [problem, ','.join(project['project_id'] for project in _projects if problem in project['problems'])]
            for problem in problems
        ]

        print("# Exceptions", file=report_file)
        if exception_rows:
            write_table(report_file, pd.DataFrame(exception_rows, columns=['problem', 'affected_workspaces']))
        else:
            print("No workspaces have exceptions!", file=report_file)

        print("# Consistent workspaces", file=report_file)
        # list consistent workspaces (those with an empty problem list)
        consistent = [project['project_id'] for project in _projects if len(project['problems']) == 0]
        if consistent:
            write_table(report_file, pd.DataFrame(consistent, columns=['workspace']))
        else:
            print("None", file=report_file)

        # Issues/Questions arising from Gen3 PFB
        summarize_workspaces()

        # A single connection serves both the load and the report queries and is
        # always closed; the original opened two connections and closed neither.
        with closing(sqlite3.connect(drs_output_path)) as conn:
            cur = conn.cursor()

            #
            # load the terra dashboard summary into db
            #
            # The UNIQUE index is created BEFORE loading so that `REPLACE`
            # de-duplicates rows. Created afterwards, REPLACE degrades to a plain
            # insert and duplicate rows make the index creation fail.
            cur.executescript("""
            drop table if exists terra_details ;
            CREATE TABLE IF NOT EXISTS terra_details (
                workspace_id text,
                subject_id text,
                sample_id text,
                blob text
            );
            CREATE UNIQUE INDEX IF NOT EXISTS terra_details_idx ON terra_details(workspace_id, subject_id, sample_id, blob);
            """)
            conn.commit()
            logger.info(f"created table {drs_output_path}")
            logger.info(f"created index {drs_output_path}")

            logger.info(f"loading from  {terra_summary}")
            with open(terra_summary, 'rb') as fo:
                # Stream the file line by line rather than materialising it with readlines().
                records = (json.loads(line) for line in fo)
                cur.executemany(
                    "REPLACE into terra_details values (?, ?, ?, ?);",
                    (
                        (record['workspace_id'], record['subject_id'], record['sample_id'], record['blob'])
                        for record in records
                    )
                )
            conn.commit()

            logger.info(f"flattening and querying table {drs_output_path}")
            cur.executescript(sql)
            conn.commit()

            logger.info("loaded table")

            for heading, query in reconcile_sections:
                df = pd.read_sql_query(query, conn)
                print(heading, file=report_file)
                write_table(report_file, df)


# Run the click command group only when this file is executed as a script.
if __name__ == '__main__':
    cli()
