#!/usr/bin/env python3

"""Reconcile and aggregate results."""

import sqlite3
from gen3.auth import Gen3Auth
from gen3.submission import Gen3Submission
from anvil.clients.gen3_auth import Gen3TerraAuth
from anvil.clients.gen3_auth import TERRA_TOKEN_URL
from gen3.query import Gen3Query

import os
import shutil
import logging
import json

import click

from anvil.util.reconciler import DEFAULT_CONSORTIUMS, DEFAULT_OUTPUT_PATH, DEFAULT_NAMESPACE, aggregate
from anvil.terra.reconciler import Entities


from datetime import date, datetime
from anvil.util.reconciler import flatten
import pandas as pd
from tabulate import tabulate
import anvil.util.data_ingestion_tracker

# Configure root logging once at import time; every command below logs through it.
logging.basicConfig(
    format='%(asctime)s %(filename)s %(levelname)-8s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Location of the gen3 API-key file used when native (non-terra) credentials are requested.
DEFAULT_GEN3_CREDENTIALS_PATH = os.path.expanduser('~/.gen3/credentials.json')
# The OUTPUT_PATH environment variable, when set, overrides the library's default output directory.
DEFAULT_OUTPUT_PATH = os.environ.get('OUTPUT_PATH', DEFAULT_OUTPUT_PATH)
# Default spreadsheet file shared by the data_ingestion_tracker and terms_lookup commands.
DEFAULT_SPREADSHEET_PATH = f"{DEFAULT_OUTPUT_PATH}/spreadsheet.json"

@click.group()
@click.pass_context
def cli(ctx):
    """Set up context, main entrypoint."""
    # NOTE: click renders the docstring above as the group's --help text.
    # Every subcommand in this file registers itself on this group via @cli.command(...).
    #
    # Ensure that ctx.obj exists and is a dict, in case we eventually want to
    # chain these commands together into a pipeline that shares state.
    ctx.ensure_object(dict)


@cli.command('data_ingestion_tracker')
@click.option('--spreadsheet_path', default=DEFAULT_SPREADSHEET_PATH, show_default=True, help="Where to read/write spreadsheet.")
def data_ingestion_tracker(spreadsheet_path):
    """Read spreadsheet, write to json file."""
    # NOTE: click renders the docstring above as this command's --help text.
    # This command shares its name with the library module and function it wraps;
    # all of the work is delegated to that library function.
    run_tracker = anvil.util.data_ingestion_tracker.data_ingestion_tracker
    run_tracker(spreadsheet_path)


@cli.command('terms_lookup')
@click.option('--spreadsheet_path', default=DEFAULT_SPREADSHEET_PATH, show_default=True, help="Where to read/write spreadsheet.")
# show_default is deliberately off here: the default comes from the environment
# and would otherwise print the secret API key in the --help output.
@click.option('--api_key', default=os.environ.get('BIOONTOLOGY_API_KEY', None), show_default=False, help="bioontology API key.  See https://bioportal.bioontology.org/help#Getting_an_API_key")
def terms_lookup(spreadsheet_path, api_key):
    """Read spreadsheet, ensure that indicators exist in disease normalizer."""
    # NOTE: click renders the docstring above as this command's --help text, so
    # implementation notes live in comments.
    #
    # spreadsheet_path: JSON file holding a list of project records; read, and
    #                   rewritten in place when every indication is already known.
    # api_key:          bioontology API key handed to lookup_term.
    assert api_key, "Please provide an API key.  See https://bioportal.bioontology.org/help#Getting_an_API_key"

    # Imported lazily so the other commands do not pay for these modules.
    from anvil.util.bioontology_lookup import lookup_term
    from anvil.transformers.fhir.disease_normalizer import text_ontology

    # Context manager so the handle is closed before the same path is rewritten below.
    with open(spreadsheet_path) as spreadsheet:
        projects = list(json.load(spreadsheet))

    indications = {project['indication'] for project in projects if project['indication']}
    logger.debug(f"{len(indications)} unique indictors")
    # Only indications the normalizer does not already know need a remote lookup.
    indications = [indication for indication in indications if indication not in text_ontology]
    logger.debug(f"{len(indications)} unique indictors remaining after text_ontology lookup")

    # Map each unknown indication to the first [ontology, class id] suggestion,
    # or None when the lookup yields nothing.
    term_ids = {}
    for term in indications:
        first_match = next(iter(lookup_term(term, api_key)), None)
        if first_match is None:
            term_ids[term] = None
        else:
            ontology, term_class = first_match
            term_ids[term] = [ontology, term_class['@id']]

    if term_ids:
        # Unknown indications must be added to the normalizer by hand;
        # the spreadsheet is left untouched until then.
        logger.error(f"Please add the following to  anvil.transformers.fhir.disease_normalizer {term_ids}")
        return

    logger.info(f"OK, updating spreadsheet {spreadsheet_path} with diseaseOntologyId")
    for project in projects:
        project['diseaseOntologyId'] = text_ontology.get(project['indication'], None)
    # Context manager guarantees the rewritten spreadsheet is flushed and closed.
    with open(spreadsheet_path, 'w') as spreadsheet:
        json.dump(projects, spreadsheet)


@cli.command('clean')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, help=f'output path default={DEFAULT_OUTPUT_PATH}')
def cleaner(output_path):
    """Remove sqlite work databases, json files, etc."""
    # NOTE: click renders the docstring above as this command's --help text.
    log = logging.getLogger(__name__)

    def _discard(target, remover):
        """Best-effort removal of *target* using *remover*; failures are logged, never raised."""
        try:
            remover(target)
        except FileNotFoundError:
            log.warning(f"{target} FileNotFound")
        except Exception as error:
            log.error(f"{target} {error}")
        else:
            log.info(f"Dropped {target}")

    # Work files that live directly under output_path, removed in this order.
    work_files = (
        'data_dashboard.json',
        'data_dashboard.tsv',
        'gen3-drs.sqlite',
        'spreadsheet.json',
        'terra.sqlite',
        'terra-graph.sqlite',
        'pyanvil-cache.sqlite',
        'qa-report.md',
        'terra_summary.json',
        'drs_file.sqlite',
    )
    for file_name in work_files:
        _discard(f'{output_path}/{file_name}', os.unlink)

    # TODO refactor cache to configure path
    _discard('/tmp/pyanvil-cache.sqlite', os.unlink)

    # One directory per consortium, named after the first element of each entry.
    for consortium in DEFAULT_CONSORTIUMS:
        _discard(f'{output_path}/{consortium[0]}', shutil.rmtree)


@cli.command('drs-extract')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, help=f'output path default={DEFAULT_OUTPUT_PATH}')
@click.option('--gen3_credentials_path', default=DEFAULT_GEN3_CREDENTIALS_PATH, help=f'gen3 native credentials={DEFAULT_GEN3_CREDENTIALS_PATH}')
# Declared as an on/off pair: a plain flag that defaults to True can never be
# switched off, which made the native-credentials branch below unreachable.
@click.option('--use_terra_credentials/--no_use_terra_credentials', default=True, help='Running in terra VM, use terra authenticator to access gen3.')
@click.option('--expected_row_count', default=175000, help="Minimum number of file records expected.")
def drs_extractor(gen3_credentials_path, output_path, use_terra_credentials, expected_row_count):
    """Retrieve DRS url for all gen3 files."""
    # NOTE: click renders the docstring above as this command's --help text, so
    # implementation notes live in comments.
    #
    # gen3_credentials_path: gen3 API-key file, used only with --no_use_terra_credentials.
    # output_path:           directory that receives drs_file.sqlite.
    # use_terra_credentials: authenticate through terra instead of a gen3 API key.
    # expected_row_count:    the download must return more than this many records.
    #
    # Raises AssertionError when too few file records are returned.
    log = logging.getLogger(__name__)
    gen3_endpoint = "https://gen3.theanvil.io"

    if use_terra_credentials:
        auth = Gen3TerraAuth(endpoint=gen3_endpoint, terra_auth_url=TERRA_TOKEN_URL, user_email=None)
    else:
        # Expects an API key downloaded from the commons' "Profile" page,
        # installed by default at ~/.gen3/credentials.json.
        auth = Gen3Auth(endpoint=gen3_endpoint, refresh_file=gen3_credentials_path)

    # Each field is requested once (the original list named md5sum twice).
    fields = [
        'node_id',
        'project_id',
        'anvil_project_id',
        'subject_submitter_id',
        'sample_submitter_id',
        'sequencing_assay_submitter_id',
        'file_name',
        'file_size',
        'md5sum',
        'submitter_id',
        'drs_id',
        '_subject_id',
    ]

    query_client = Gen3Query(auth)
    log.info('Starting export of data from gen3.')
    raw_data = query_client.raw_data_download(data_type='file', fields=fields)

    assert len(raw_data) > expected_row_count, f"Expected over {expected_row_count} file records, got {len(raw_data)} instead.  Projects {sorted({r['project_id'] for r in raw_data})}"

    log.info(f"retrieved {len(raw_data)} file records from gen3.")

    def _first(values):
        """Return the first element of *values*, or None when it is empty or missing."""
        return values[0] if values else None

    sqlite_path = f'{output_path}/drs_file.sqlite'
    # A single connection for the whole import, closed in the finally block below.
    conn = sqlite3.connect(sqlite_path, check_same_thread=False, isolation_level='DEFERRED')
    try:
        # Optimize for single-thread speed. PRAGMA settings are per-connection, so
        # this must run on the connection that performs the import.
        # journal_mode = OFF is intentionally not applied: without a journal SQLite
        # cannot unwind a statement that fails on a constraint, and such failures
        # are handled (logged, then stop) in the loop below.
        conn.execute('PRAGMA synchronous = OFF')
        conn.executescript("""
        CREATE TABLE IF NOT EXISTS drs_file (
            md5sum text,
            sequencing_id text PRIMARY KEY,
            file_name text,
            ga4gh_drs_uri text,
            sample_submitter_id text,
            subject_submitter_id text,
            subject_id text,
            project_id text,
            anvil_project_id text
        );
        """)
        conn.commit()

        log.info(f'Starting import of data into sqlite {sqlite_path}.')
        cur = conn.cursor()
        commit_threshold = 1000
        for count, row in enumerate(raw_data, start=1):
            try:
                cur.execute(
                    "INSERT into drs_file values (?, ?, ?, ?, ?, ?, ?, ?, ?);",
                    (
                        row['md5sum'],
                        row['node_id'],
                        row['file_name'],
                        row['drs_id'],
                        _first(row['sample_submitter_id']),
                        _first(row['subject_submitter_id']),
                        _first(row['_subject_id']),
                        row['project_id'],
                        _first(row['anvil_project_id']),
                    ),
                )
            except Exception:
                # Best effort: report the offending row, stop importing, and keep
                # whatever was written so far.
                log.error(row)
                log.error("Fatal error writing row above to sqlite.", exc_info=True)
                break
            # Commit in batches to bound the size of the open transaction.
            if count % commit_threshold == 0:
                conn.commit()
        conn.commit()

        log.info('Indexing')
        conn.executescript("""
        CREATE INDEX IF NOT EXISTS drs_file_md5sum ON drs_file(md5sum);
        CREATE  INDEX IF NOT EXISTS drs_file_file_name ON drs_file(file_name);
        """)
        conn.commit()
        log.info(f'Created {sqlite_path}')

        # Per-project file counts. Column names come from the cursor description so
        # an empty table still yields a header instead of an IndexError.
        summary = conn.execute('SELECT project_id as "gen3_project_id", anvil_project_id, count(*) as "file_count" FROM drs_file group by  project_id, anvil_project_id')
        header = [column[0] for column in summary.description]
        rows = summary.fetchall()
    finally:
        conn.close()

    log.info(f"\nExtracted File Counts\n{tabulate(rows, header)}")
    

@cli.command('extract')
@click.option('--user_project', default=os.environ.get('GOOGLE_PROJECT', None), help=f'Google billing project. default={os.environ.get("GOOGLE_PROJECT", None)}')
@click.option('--namespace', default=DEFAULT_NAMESPACE, help=f'Terra namespace default={DEFAULT_NAMESPACE}')
@click.option('--consortiums', type=(str, str), default=DEFAULT_CONSORTIUMS, multiple=True, help=f'<Name Regexp> e.g "CCDG AnVIL_CCDG.*" default {DEFAULT_CONSORTIUMS}')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, help=f'output path default={DEFAULT_OUTPUT_PATH}')
def extractor(user_project, namespace, consortiums, output_path):
    """Harvest all workspaces, return list of workspace_name. Create detailed sqlite graph and summary dashboard."""
    # NOTE: click renders the docstring above as this command's --help text, so
    # implementation notes live in comments.
    #
    # user_project: Google billing project (required).
    # namespace:    Terra namespace to harvest.
    # consortiums:  (name, workspace-name regexp) pairs selecting the workspaces.
    # output_path:  directory holding the work databases and the dashboard.
    #
    # Raises AssertionError when user_project is missing, when no workspace matched,
    # or when an expected output file was not produced.
    log = logging.getLogger(__name__)
    log.info("Starting aggregation for all specified AnVIL workspaces this will take several minutes.")
    log.info(f"Reading from consortiums {consortiums}")
    log.info(f"Writing to output_path {output_path}")
    assert user_project, "Please provide --user_project (or set GOOGLE_PROJECT)"

    dashboard_output_path = f"{output_path}/data_dashboard.json"
    terra_output_path = f"{output_path}/terra.sqlite"
    drs_file_path = f"{output_path}/drs_file.sqlite"
    spreadsheet_path = f"{output_path}/spreadsheet.json"

    # The dashboard file is opened before aggregating, so an unwritable
    # output_path fails before the long-running harvest starts.
    with open(dashboard_output_path, 'w') as outs:
        # Pass the --namespace option through; the original always passed
        # DEFAULT_NAMESPACE, silently ignoring the option.
        views = list(aggregate(namespace=namespace,
                               user_project=user_project,
                               consortium=consortiums, drs_file_path=drs_file_path,
                               terra_output_path=terra_output_path,
                               data_ingestion_tracker=spreadsheet_path))
        # Views carrying a 'problems' key are projects; the rest are consortium roll-ups.
        projects = [v for v in views if 'problems' in v]
        json.dump({
            'projects': projects,
            'consortiums': [v for v in views if 'problems' not in v]
        }, outs)

    assert len(projects) > 0, f"{consortiums} matched no workspaces"
    assert os.path.isfile(dashboard_output_path), f"{dashboard_output_path} should exist."
    log.info(f"Wrote summary to {dashboard_output_path}")

    assert os.path.isfile(terra_output_path), f"{terra_output_path} should exist."
    # Create the sql indices on the freshly written work database.
    entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
    entities.index()
    log.info(f"Wrote work database to {terra_output_path}")


def _report_print_table(frame, report_file):
    """Append *frame* to the report as a GitHub-flavoured markdown table."""
    print(tabulate(frame, headers='keys', tablefmt='github'), file=report_file)


def _report_write_dashboard_sections(dashboard_data, output_path, report_file):
    """Write data_dashboard.tsv and the dashboard, exceptions and consistency report sections."""
    (flattened, column_names) = flatten(dashboard_data['projects'])
    df = pd.DataFrame(flattened)
    df.columns = column_names
    df.to_csv(f"{output_path}/data_dashboard.tsv", sep="\t")
    logging.getLogger(__name__).info(f"Wrote {output_path}/data_dashboard.tsv")

    print("\n# Dashboard\n", file=report_file)
    _report_print_table(df, report_file)

    # Summarize terra exceptions: one row per problem with the workspaces it affects.
    _projects = [project for project in dashboard_data['projects'] if 'problems' in project]
    problems = {problem for project in _projects for problem in project['problems']}
    # Sorted so the report is reproducible from run to run (set order is not).
    exception_rows = [
        [problem, ','.join(project['project_id'] for project in _projects if problem in project['problems'])]
        for problem in sorted(problems, key=str)
    ]

    print("\n# Exceptions\n> Extract the list of data transformation problems encountered [see more on dashboard exceptions](https://github.com/anvilproject/client-apis/wiki/dashboard-exceptions)\n", file=report_file)
    if exception_rows:
        df = pd.DataFrame(exception_rows)
        df.columns = ['problem', 'affected_workspaces']
        _report_print_table(df, report_file)
    else:
        print("No workspaces have exceptions!", file=report_file)

    print("# Consistent workspaces", file=report_file)
    consistent = [project['project_id'] for project in _projects if not project['problems']]
    if consistent:
        df = pd.DataFrame(consistent)
        df.columns = ['workspace']
        _report_print_table(df, report_file)
    else:
        print("None", file=report_file)


def _report_write_terra_summary(terra_output_path, user_project, terra_summary):
    """Write one JSON line per (workspace, subject, sample, blob) found in the terra work database."""
    entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
    # created sql indices
    entities.index()
    with open(terra_summary, "w") as emitter:
        for workspace in entities.get_by_name('workspace'):
            for subject in workspace.subjects:
                for sample in subject.samples:
                    for _property_name, blob in sample.blobs.items():
                        json.dump(
                            {
                                "workspace_id": workspace.id,
                                "subject_id": subject.id,
                                "sample_id": sample.id,
                                "blob": blob['name'],
                            },
                            emitter,
                            separators=(',', ':')
                        )
                        emitter.write('\n')
    logging.getLogger(__name__).info(f"Wrote summary to {terra_summary}")


def _report_load_terra_details(conn, drs_output_path, terra_summary):
    """Rebuild the terra_details table in *conn* from the JSON-lines file *terra_summary*."""
    log = logging.getLogger(__name__)
    # The unique index is created BEFORE loading: REPLACE only de-duplicates when a
    # uniqueness constraint exists. Loading first and indexing afterwards makes
    # REPLACE a plain insert, and the index creation fails on duplicate rows.
    conn.executescript("""
    drop table if exists terra_details ;
    CREATE TABLE IF NOT EXISTS terra_details (
        workspace_id text,
        subject_id text,
        sample_id text,
        blob text
    );
    CREATE UNIQUE INDEX IF NOT EXISTS terra_details_idx ON terra_details(workspace_id, subject_id, sample_id, blob);
    """)
    conn.commit()
    log.info(f"created terra_details {drs_output_path}")
    log.info(f"created index {drs_output_path}")

    log.info(f"loading from  {terra_summary}")
    with open(terra_summary, 'rb') as fo:
        records = (json.loads(line) for line in fo)
        conn.executemany(
            "REPLACE into terra_details values (?, ?, ?, ?);",
            ((r['workspace_id'], r['subject_id'], r['sample_id'], r['blob']) for r in records),
        )
    conn.commit()


def _report_build_reconciliation_tables(conn, drs_output_path):
    """Create the summary and reconcile_counts tables used by the report queries."""
    log = logging.getLogger(__name__)

    log.info(f"create table summary {drs_output_path}")
    conn.executescript("""
    drop table if exists summary ;
    create table summary
    as
    select
        f.project_id, f.anvil_project_id,
        count(distinct f.subject_id) as "subject_count",
        count(distinct f.sample_submitter_id) as "sample_count",
        count(distinct f.sequencing_id) as "sequencing_count",
        count(distinct f.ga4gh_drs_uri) as "ga4gh_drs_uri_count"
    from drs_file as f
    group by f.project_id, f.anvil_project_id;
    """)
    conn.commit()

    # Terra sample ids are matched to gen3 either with a '_sample' suffix or verbatim.
    # The original repeated the verbatim branch a second time; UNION removes
    # duplicate rows, so the repeat had no effect and is omitted here.
    # Workspaces with no gen3 match at all are then added with zero gen3 counts.
    log.info(f"creating reconciliation tables {drs_output_path}")
    conn.executescript("""
    drop table if exists reconcile_counts;
    create table reconcile_counts as
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        count(distinct f.sample_submitter_id) as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        count(distinct f.ga4gh_drs_uri) as "gen3_drs_uri_count"
        from terra_details as w
            left join drs_file as f on (w.sample_id || '_sample' = f.sample_submitter_id)
    group by w.workspace_id
    having gen3_sample_id_count > 0
    UNION
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        count(distinct f.sample_submitter_id) as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        count(distinct f.ga4gh_drs_uri) as "gen3_drs_uri_count"
        from terra_details as w
            left join drs_file as f on (w.sample_id   = f.sample_submitter_id)
    group by w.workspace_id
    having gen3_sample_id_count > 0
    ;

    insert into reconcile_counts
    select w.workspace_id,
        count(distinct w.sample_id) as "terra_sample_id_count",
        0 as "gen3_sample_id_count",
        count(distinct w.blob) as "terra_blob_count",
        0 as "gen3_drs_uri_count"
    from terra_details  as w
    where workspace_id not in ( select distinct workspace_id from reconcile_counts )
    group by w.workspace_id
    ;
    """)
    conn.commit()
    log.info("created all summary tables")


def _report_write_reconciliation_sections(conn, report_file):
    """Append one markdown table per reconciliation query to the report."""
    sections = (
        ("\n# Gen3 projects without anvil(terra) project\n",
         "SELECT * from summary where anvil_project_id is null;"),
        ("\n# Not all terra projects found in Gen3\n",
         "SELECT * from reconcile_counts where gen3_sample_id_count = 0;"),
        ("\n# Terra / Gen3 samples count mismatch\n",
         "SELECT * from reconcile_counts where gen3_sample_id_count > 0 and gen3_sample_id_count <> terra_sample_id_count;"),
        ("\n# Terra / Gen3 blob/drs count alignment\n",
         "SELECT * from reconcile_counts where terra_sample_id_count = gen3_sample_id_count and terra_blob_count = gen3_drs_uri_count;"),
        ("\n# Terra / Gen3 blob/drs count mismatch\n",
         "SELECT * from reconcile_counts where terra_sample_id_count = gen3_sample_id_count and terra_blob_count <> gen3_drs_uri_count;"),
    )
    for heading, query in sections:
        print(heading, file=report_file)
        _report_print_table(pd.read_sql_query(query, conn), report_file)


@cli.command('report')
@click.option('--output_path', default=DEFAULT_OUTPUT_PATH, show_default=True, help='output path.')
@click.option('--user_project', default=os.environ.get('GOOGLE_PROJECT', None), show_default=True, help='Google billing project.')
def reporter(output_path, user_project):
    """Reconcile and report on harvested workspaces."""
    # NOTE: click renders the docstring above as this command's --help text, so
    # implementation notes live in comments.
    #
    # output_path:  directory holding terra.sqlite, drs_file.sqlite and
    #               data_dashboard.json; receives qa-report.md, data_dashboard.tsv
    #               and terra_summary.json.
    # user_project: Google billing project passed through to Entities.
    #
    # Raises AssertionError when a required input file is missing or the
    # dashboard has no projects.
    log = logging.getLogger(__name__)
    terra_output_path = f"{output_path}/terra.sqlite"
    dashboard_output_path = f"{output_path}/data_dashboard.json"
    drs_output_path = f"{output_path}/drs_file.sqlite"
    terra_summary = f"{output_path}/terra_summary.json"
    report_path = f"{output_path}/qa-report.md"

    log.info("Starting reporting for all extracted AnVIL workspaces this will take several minutes.")
    entities = Entities(terra_output_path=terra_output_path, user_project=user_project)
    workspace_names = [workspace.name for workspace in entities.get_by_name('workspace')]
    log.info(f"Reporting on {len(workspace_names)} workspaces")

    # Validate inputs before touching the report, so a failed run does not
    # truncate a previously written qa-report.md.
    assert os.path.isfile(dashboard_output_path), f"dashboard should exist {dashboard_output_path}"
    with open(dashboard_output_path, 'r') as inputs:
        dashboard_data = json.load(inputs)
    assert len(dashboard_data['projects']) > 0, f"dashboard_data['projects'] was empty? {dashboard_output_path}"
    # sqlite3.connect() would silently create an empty database for a missing
    # file and the reconciliation SQL would then fail on the absent drs_file table.
    assert os.path.isfile(drs_output_path), f"{drs_output_path} should exist."

    log.info(f"Writing report to {report_path}")
    with open(report_path, 'w') as report_file:
        _report_write_dashboard_sections(dashboard_data, output_path, report_file)
        _report_write_terra_summary(terra_output_path, user_project, terra_summary)

        conn = sqlite3.connect(drs_output_path)
        try:
            _report_load_terra_details(conn, drs_output_path, terra_summary)
            _report_build_reconciliation_tables(conn, drs_output_path)
            _report_write_reconciliation_sections(conn, report_file)
        finally:
            conn.close()

    log.info(f"created report {report_path}")


# Run the click command group when this file is executed as a script.
if __name__ == '__main__':
    cli()
