#!/usr/bin/env python3
import argparse

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
# Make SVG the default static-export format. The kaleido scope object may be
# missing/None when the optional kaleido package is not installed (NOTE(review):
# confirm for the plotly version in use), so guard the assignment: the script
# should still be able to display the figure interactively in that case.
_kaleido_scope = getattr(pio.kaleido, "scope", None)
if _kaleido_scope is not None:
    _kaleido_scope.default_format = "svg"

import coincidencetest
from coincidencetest._coincidencetest import calculate_probability_of_multicoincidence
from coincidencetest import coincidencetest

# Command-line interface: one or more set sizes (positional), the size of the
# ambient set (-n, required), and an optional output filename. When no output
# filename is given the figure is shown interactively instead of being saved.
parser = argparse.ArgumentParser(
    description = 'Shows PDF of intersection statistic.'
)
parser.add_argument(
    'set_sizes',
    type=int,
    nargs='+',
    help='The sizes of the sets being intersected (one or more integers).',
)
parser.add_argument(
    '-n',
    dest='n',
    type=int,
    required=True,
    help='The size of the ambient set.',
)
parser.add_argument(
    '--output-filename',
    dest='output_filename',
    type=str,
    default=None,
    required=False,
    help='Output SVG filename.'
)
args = parser.parse_args()

# Reject inputs that cannot describe subsets of an ambient set. parser.error
# prints the usage message and exits with status 2, like any other argparse
# validation failure.
if args.n < 1:
    parser.error('-n must be a positive integer.')
if any(size < 0 or size > args.n for size in args.set_sizes):
    parser.error('each set size must be between 0 and the ambient set size (-n).')

set_sizes = args.set_sizes
ambient_size = args.n

# Intersection sizes 1..ambient_size. This single list is the x-axis for both
# the bar chart and the CDF line, so every series below is built from it and
# must end up with exactly len(I_values) entries.
I_values = list(range(1, ambient_size + 1))

# Probability of each exact intersection size, as reported by the library.
pdf = [
    calculate_probability_of_multicoincidence(
        intersection_size = intersection_size,
        set_sizes = set_sizes,
        ambient_size = ambient_size,
    ) for intersection_size in I_values
]

# One minus the test's p-value at each intersection size.
# NOTE(review): assumes coincidencetest returns the upper-tail probability
# P(I >= i), so that 1 - p is P(I <= i - 1) -- confirm against the library.
complement_p_values = [
    1 - coincidencetest(
        incidence_statistic = intersection_size,
        frequencies = set_sizes,
        number_samples = ambient_size,
        format_p_value = False,
    ) for intersection_size in I_values
]

# Shift left by one so entry k lines up with intersection size k + 1, and
# close the curve with 1 at the largest size. Only the first element is
# dropped: also dropping the last one (as `[1:-1]` did) left the CDF one entry
# shorter than I_values and put the terminal 1 at the wrong x position.
cdf = complement_p_values[1:] + [1]

# Intersection size expressed as a fraction of the ambient set.
ratio_values = [intersection_size / ambient_size for intersection_size in I_values]

df = pd.DataFrame({'probability' : pdf, 'intersection size' : I_values, 'intersection out of whole' : ratio_values})

# Bar chart of probability against intersection size.
fig = px.bar(df, x='intersection size', y='probability')

# Overlay the CDF values as a thin stepped line on the same axes.
cdf_trace = go.Scatter(
    x=I_values,
    y=cdf,
    mode="lines",
    line={'color': 'rgb(245, 15, 10)', 'shape': 'hvh', 'width': 1},
)
fig.add_trace(cdf_trace, row=1, col=1)

# Fix the y-axis to [0, 1], remove the gaps between bars and hide the legend.
fig.update_layout(yaxis_range=[0, 1], bargap=0.0, showlegend=False)

# Without an output filename, display the figure; otherwise render it to SVG
# bytes and write them to the requested file.
if not args.output_filename:
    fig.show()
else:
    with open(args.output_filename, 'wb') as svg_file:
        svg_file.write(fig.to_image(format="svg", engine="kaleido"))
