#!/usr/bin/env python3
import argparse
import random
import re

import pandas as pd

import coincidencetest
from coincidencetest import find_concepts
from coincidencetest import coincidencetest

def listify(table):
    """Return the rows of ``table`` as a list of plain lists, in row order.

    ``table`` is iterated with ``iterrows()``; the row labels are discarded
    and each row's values are copied into a new list.
    """
    rows_as_lists = []
    for _, row in table.iterrows():
        rows_as_lists.append(list(row))
    return rows_as_lists

def format_row(i, l, element_formatter=lambda x:x, class_chooser=lambda x:'', columns=None):
    """Render the elements of ``l`` as a run of HTML ``<td>`` cells.

    The surrounding ``<tr>`` tags are not included.

    Parameters:
        i: Position of this row; only used to look up ``columns[i]``.
        l: Iterable of the row's elements.
        element_formatter: Maps an element to the string placed inside its
            cell. Defaults to the identity, so elements must then be strings.
        class_chooser: Maps an element to the CSS class of its cell.
        columns: Optional indexable of row names. When given, a leading
            ``<td class="rowname">`` cell holding ``columns[i]`` is prepended.

    Returns:
        The concatenated cell markup as a single string.
    """
    cells = []
    for element in l:
        opening_tag = '<td class="%s">' % class_chooser(element)
        cells.append(opening_tag + element_formatter(element) + '</td>')

    if columns is None:
        return ''.join(cells)

    name_cell = '<td class="rowname">' + columns[i] + '</td>'
    return name_cell + ''.join(cells)

def format_table(array, element_formatter=lambda x:x, class_chooser=lambda x:'', columns=None):
    """Render ``array`` (an iterable of rows) as an HTML ``<table>`` string.

    Each row is rendered by ``format_row`` with its position in ``array`` and
    the same ``element_formatter``, ``class_chooser`` and ``columns``
    arguments; the rendered rows are then wrapped in ``<tr>`` tags.

    Returns:
        The complete ``<table>...</table>`` markup as a single string.
    """
    rendered_rows = []
    for position, row in enumerate(array):
        rendered_rows.append(
            format_row(
                position,
                row,
                element_formatter=element_formatter,
                class_chooser=class_chooser,
                columns=columns,
            )
        )
    table_body = '</tr><tr>'.join(rendered_rows)
    return '<table><tr>' + table_body + '</tr></table>'

def report_test_result(table):
    """Run the coincidence test on ``table`` and print a short summary.

    All quantities are derived from the ``table`` argument itself, so the
    function does not depend on any module-level data frame being defined.

    Parameters:
        table: pandas DataFrame whose rows are samples and whose columns are
            features. Entries are summed, so they are presumably 0/1 values
            (the script's help text describes the input as a binary matrix).

    Prints the incidence statistic (number of rows whose sum equals the
    number of columns), the per-column sums, the number of rows, and the
    value returned by ``coincidencetest``, labelled as the p-value.
    """
    number_samples, number_features = table.shape
    # Rows whose entries sum to the column count, i.e. all ones for 0/1 data.
    intersection = [i for i, row in table.iterrows() if sum(row) == number_features]
    # One sum per column: iterate the rows of the transposed table.
    frequencies = [sum(row) for i, row in table.transpose().iterrows()]
    pvalue = coincidencetest(
        incidence_statistic=len(intersection),
        frequencies=frequencies,
        number_samples=number_samples,
    )
    message = '''
    Incidence statistic: %s
    Frequencies:         %s
    Number of samples:   %s
    p-value:             %s
    ''' % (len(intersection), ' '.join(str(f) for f in frequencies), number_samples, pvalue)
    print(message)

if __name__ == '__main__':
    # Command-line entry point: read a binary matrix, write an HTML file that
    # shows it twice (rows in random order, then with the all-ones rows moved
    # to the front), and print the coincidence test summary.
    parser = argparse.ArgumentParser(
        description = ''.join([
            'This program computes "formal concepts" (maximal biclusters) in ',
            'binary feature data, assesses them using the exact test for ',
            'coincidence, and reports the results.',
        ])
    )
    parser.add_argument(
        'input_filename',
        type=str,
        help='CSV or TSV file containing a binary matrix with row and column names included.',
    )
    parser.add_argument(
        '--delimiter',
        dest='delimiter',
        type=str,
        default='tab',
        help='Either a comma (the character ,) or the word tab. Default is tab.',
    )
    args = parser.parse_args()

    # Any value other than the word "tab" is treated as a comma delimiter.
    separator = '\t' if args.delimiter == 'tab' else ','
    df = pd.read_csv(args.input_filename, sep=separator, index_col=0)

    N = df.shape[1]
    # Row labels whose entries sum to the number of columns.
    intersection = [i for i, row in df.iterrows() if sum(row) == N]
    # Set lookup keeps the remainder computation linear in the number of rows.
    intersection_labels = set(intersection)
    remainder = [i for i in df.index if i not in intersection_labels]
    # dfs: intersection rows first, the rest shuffled. df: everything shuffled.
    sort_order = intersection + random.sample(remainder, len(remainder))
    sort_order_random = random.sample(intersection + remainder, df.shape[0])
    dfs = df.loc[sort_order]
    df = df.loc[sort_order_random]

    style = '''
    body {
        font-family: sans-serif;
    }
    table {
        background: white;
        padding: 0px;
        display: inline;
    }
    table tr td{
        width: 1px;
        height: 50px;
    }
    td.one {
        background: #f25b24;
    }
    td.zero {
        background: #e9e3f8;
    }
    td.rowname {
        background: white;
        border: solid white 1px;
        padding: 10px;
        width: auto;
        font-size: 2.2em;
    }
    div.figdiv {
        text-align: center;
    }
    div.tablewrapper {
        display: inline-block;
    }
    '''

    # The tables are transposed: each HTML row is one feature (labelled with
    # its column name) and each cell is one sample.
    table1 = format_table(
        listify(df.transpose()),
        element_formatter=lambda x: '',
        class_chooser=lambda x: 'one' if x == 1 else 'zero',
        columns = df.columns,
    )
    table2 = format_table(
        listify(dfs.transpose()),
        element_formatter=lambda x: '',
        class_chooser=lambda x: 'one' if x == 1 else 'zero',
        columns = df.columns,
    )

    html = '''
    <html>
    <head>
    <style>
    %s
    </style>
    </head>
    <body>
    <div class="figdiv">
    <div class="tablewrapper">
    %s
    </div>
    <br><br><br>
    <div class="tablewrapper">
    %s
    </div>
    </div>
    </body>
    </html>
    ''' % (style, table1, table2)
    report_test_result(df)

    filename = re.sub(r'(csv|tsv)$', 'html', args.input_filename)
    if filename == args.input_filename:
        # The name has no csv/tsv suffix to swap out; append one instead so
        # the report can never overwrite the input file.
        filename = args.input_filename + '.html'
    # Context manager guarantees the report is flushed and the handle closed.
    with open(filename, 'wt') as html_file:
        html_file.write(html)
    print('Wrote ' + filename + '\n')
