#!/usr/bin/env python3
import argparse
import math

import pandas as pd
import igraph
from igraph import Graph

import coincidencetest
from coincidencetest import find_concepts
from coincidencetest import coincidencetest

def gather_data(data, closed_sets, dual_sets):
    """Summarize the input matrix together with its multi-feature groups.

    Args:
        data: pandas DataFrame; rows are counted as samples and columns as
            features.
        closed_sets: Sequence of feature collections (signatures).
        dual_sets: Sequence parallel to ``closed_sets``; entry ``i`` is the
            collection of samples belonging to ``closed_sets[i]``.

    Returns:
        A dict with the keys ``'number of samples'``, ``'number of
        features'``, ``'frequencies'`` (column name -> column sum) and
        ``'groups'`` (one record per signature that has more than one
        feature, holding the signature and the size of its sample set).
    """
    sample_count, feature_count = data.shape
    # axis=0 hands each column to the function, so this is a per-feature total.
    column_totals = data.apply(lambda column: sum(column), axis=0)

    groups = []
    for position, signature in enumerate(closed_sets):
        # Signatures with zero or one feature are not reported as groups.
        if len(signature) <= 1:
            continue
        groups.append({
            'signature': signature,
            'number of samples': len(dual_sets[position]),
        })

    return {
        'number of samples': sample_count,
        'number of features': feature_count,
        'frequencies': dict(column_totals),
        'groups': groups,
    }

def do_tests(gathered):
    """Run ``coincidencetest`` on every group and tabulate the results.

    Args:
        gathered: Dict with the keys ``'frequencies'``, ``'groups'``,
            ``'number of samples'`` and ``'number of features'`` (the
            structure returned by ``gather_data``).

    Returns:
        A ``(table, original_sets)`` pair. ``table`` is a DataFrame with the
        columns ``Signature``, ``Frequency``, ``Out of`` and ``p-value``,
        sorted by ascending p-value. ``original_sets`` maps each signature
        string in the table back to the original feature collection.
    """
    columns = ['Signature', 'Frequency', 'Out of', 'p-value']
    frequencies = gathered['frequencies']
    groups = gathered['groups']
    total_samples = gathered['number of samples']
    rows = []
    original_sets = {}
    for i, item in enumerate(groups):
        signature = item['signature']
        # Arguments: size of the group, frequency of each feature in the
        # signature, total number of samples. The exact semantics of the
        # keyword options are defined by the coincidencetest package.
        p_value = coincidencetest(
            item['number of samples'],
            [frequencies[feature] for feature in signature],
            total_samples,
            format_p_value=False,
            correction_feature_set_size=gathered['number of features'],
        )
        # Sorted, space-joined feature names serve as the display key.
        s = ' '.join(sorted(signature))
        rows.append({
            'Signature' : s,
            'Frequency' : item['number of samples'],
            'Out of' : total_samples,
            'p-value' : p_value,
        })
        original_sets[s] = signature
        percent = round(100 * (i / len(groups)))
        message = ('Testing signatures: %s %s' % ('(' + str(percent) + '%)  ', s))
        # Carriage return keeps the progress report on a single console line.
        print(message + '      ', end='\r')
    print('')
    # Explicit columns keep the sort below valid when there are no groups;
    # a DataFrame built from an empty list would have no 'p-value' column.
    table = pd.DataFrame(rows, columns=columns)
    table.sort_values(by='p-value', inplace=True)
    return table, original_sets

def list_frequencies(gathered):
    """Build a two-column table of per-feature frequencies.

    Args:
        gathered: Dict providing ``'number of samples'`` and
            ``'frequencies'`` (a mapping from feature name to count).

    Returns:
        A DataFrame with one row per entry of ``'frequencies'``. The first
        column is ``Phenotype``; the second column's header embeds the total
        number of samples.
    """
    header = [
        'Phenotype',
        'Frequency (out of %s)' % gathered['number of samples'],
    ]
    records = list(gathered['frequencies'].items())
    return pd.DataFrame(records, columns=header)

def write_lattice(results_table, original_sets,
                  output_filename='coincidence_clustering_lattice.graphml'):
    """Save the containment graph of the tested signatures as GraphML.

    One vertex is created per row of ``results_table``, in row order, with
    the attributes ``Frequency``, ``log10 p`` and ``Signature``. A directed
    edge runs from vertex A to vertex B whenever every feature of A's
    signature also belongs to B's signature (all such pairs, not only
    immediate neighbours).

    Args:
        results_table: DataFrame with the columns ``Signature``,
            ``Frequency`` and ``p-value``.
        original_sets: Mapping from each signature string in the table to
            the collection of features it stands for.
        output_filename: Path of the GraphML file to write.
    """
    all_signatures = list(results_table['Signature'])
    node_attributes = {
        'Frequency' : list(results_table['Frequency']),
        # log10 is undefined at zero; a p-value of exactly 0 is recorded as
        # -inf rather than aborting the export with a math domain error.
        'log10 p' : [
            float('-inf') if p == 0 else math.log10(p)
            for p in results_table['p-value']
        ],
        'Signature' : all_signatures,
    }
    g = Graph(directed=True)
    g.add_vertices(len(all_signatures), attributes=node_attributes)

    print(all_signatures)
    # Build each feature set once; vertex ids are the row positions, so no
    # repeated list searches are needed to resolve them.
    feature_sets = [set(original_sets[signature]) for signature in all_signatures]
    edges = []
    for source, source_features in enumerate(feature_sets):
        for target, target_features in enumerate(feature_sets):
            if source == target:
                continue
            if source_features <= target_features:
                edge = (source, target)
                print('Added %s' % str(edge))
                edges.append(edge)
    # A single batched call avoids modifying the graph once per edge.
    g.add_edges(edges)

    g.save(output_filename, format='graphml')

def _parse_arguments():
    """Define the command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description = ''.join([
            'This program computes "formal concepts" (maximal biclusters) in ',
            'binary feature data, assesses them using the exact test for ',
            'coincidence, and reports the results.',
        ])
    )
    parser.add_argument(
        '--input-filename',
        dest='input_filename',
        type=str,
        required=True,
        help='CSV or TSV file containing a binary matrix with row and column names included.',
    )
    parser.add_argument(
        '--delimiter',
        type=str,
        default='tab',
        help='Either a comma (the character ,) or the word tab. Default is tab.',
    )
    parser.add_argument(
        '--level-limit',
        dest='level_limit',
        type=int,
        default=None,
        help='Per-level limit on considered pairs for bicluster discovery algorithm. See documentation for ConceptLattice.',
    )
    parser.add_argument(
        '--max-recursion',
        dest='max_recursion',
        type=int,
        default=None,
        help='Limit on number of recursion levels for bicluster discovery algorithm. See documentation for ConceptLattice.',
    )
    parser.add_argument(
        '--output-tsv',
        dest='output_tsv',
        type=str,
        required=True,
        # The option is mandatory, so the help must not describe it as optional.
        help='Output filename for the TSV table of discovered signatures.',
    )
    return parser.parse_args()


def main():
    """Read the matrix, find and test signatures, and write all outputs.

    Writes the signature table to the ``--output-tsv`` path, the feature
    frequencies to ``coincidence_clustering_frequencies.tsv``, and the
    lattice graph via ``write_lattice``.
    """
    args = _parse_arguments()

    # The CLI accepts the word 'tab' in place of a literal tab character.
    delimiter = '\t' if args.delimiter == 'tab' else args.delimiter
    # NOTE(review): no index_col is passed, so the row-name column is only
    # used as the index if pandas infers it (header one field shorter than
    # the data rows) - confirm against the expected input format.
    data = pd.read_csv(args.input_filename, delimiter=delimiter)
    closed_sets, dual_sets = find_concepts(
        data,
        level_limit=args.level_limit,
        max_recursion=args.max_recursion,
    )
    gathered = gather_data(data, closed_sets, dual_sets)

    table, original_sets = do_tests(gathered)
    table.to_csv(args.output_tsv, sep='\t', index=False)

    frequencies_table = list_frequencies(gathered)
    frequencies_table.to_csv('coincidence_clustering_frequencies.tsv', sep='\t', index=False)

    write_lattice(table, original_sets)


if __name__ == '__main__':
    main()
