#!/usr/bin/env python3
#
# Calculate statistics of values in each column
# Copyright (c) 2018-2021, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import fileinput
import math
import re
import statistics
import sys

from perlcompat import die, warn, getopts
#import tbdump

# Statistics that take two columns: main() calls them with each pair of
# adjacent columns instead of with a single column.
IS_MULTI = {'covar', 'corr', 'spearmn', 'kendall', 'rise', 'settling'}

# Fraction of the trailing samples that _extract_tail() treats as the
# converged part of a series.
STABILITY_TEST_FRACTION = 0.10
# Largest deviation of the tail's minimum/maximum from the tail mean, as a
# fraction of that mean, for is_stable() to report the series as stable.
STABILITY_THRESHOLD = 0.05
# Relative distance from the converged level tested by rise_time().
RISE_TIME_THRESHOLD = 0.2
# Relative distance from the converged level tested by settling_time().
SETTLING_TIME_THRESHOLD = 0.2

def usage():
    """Pass the command-line synopsis to die() (imported from perlcompat)."""
    program = sys.argv[0]
    synopsis = f"""\
usage: {program} [-va] [-t type[,type...]] [file...]
  -v                 display labels for each statistic
  -a                 display all statistics
  -t type[,type...]  specify statistics to display"""
    die(synopsis)

def sign(val):
    """Return -1, 0 or 1 for a negative, zero or positive VAL."""
    # Each comparison is a bool, so the difference is exactly -1, 0 or 1.
    return (val > 0) - (val < 0)

def count(vals):
    """Return the number of elements in VALS."""
    size = len(vals)
    return size

def mean(vals):
    """Return the arithmetic mean of VALS, computed by statistics.mean()."""
    average = statistics.mean(vals)
    return average

def median(vals):
    """Return the median of VALS, computed by statistics.median()."""
    middle = statistics.median(vals)
    return middle

def mode(vals):
    """Return the most frequent value in VALS.

    When several values share the highest frequency, the one whose last
    occurrence comes latest in VALS is returned (the sort below is stable
    and its final element is taken).  Raises IndexError for empty input.
    """
    frequency = {}
    for item in vals:
        frequency[item] = frequency.get(item, 0) + 1
    by_frequency = sorted(vals, key=frequency.__getitem__)
    return by_frequency[-1]

def moment(n, vals):
    """Return the N-th central moment of VALS.

    This is the mean of the N-th powers of the deviations from the mean,
    i.e. the sum is divided by the number of values.
    """
    center = mean(vals)
    powered = [pow(x - center, n) for x in vals]
    return mean(powered)

def variance(vals):
    """Return the variance of VALS, defined as its second central moment."""
    second_moment = moment(2, vals)
    return second_moment

def stddev(vals):
    """Return the standard deviation of VALS (square root of variance())."""
    spread = variance(vals)
    return math.sqrt(spread)

def skewness(vals):
    """Return the skewness of VALS: moment(3) / moment(2)**1.5.

    Returns None when the denominator is zero (for example when all values
    are identical), where the skewness is undefined; format_number() prints
    None as 'NaN'.
    """
    denominator = moment(2, vals)**1.5
    if denominator == 0:
        # A constant column has zero variance; dividing would raise
        # ZeroDivisionError and abort the whole report.
        return None
    return moment(3, vals) / denominator

def kurtosis(vals):
    """Return the excess kurtosis of VALS: moment(4) / moment(2)**2 - 3.

    Returns None when the denominator is zero (for example when all values
    are identical), where the kurtosis is undefined; format_number() prints
    None as 'NaN'.
    """
    denominator = moment(2, vals)**2
    if denominator == 0:
        # A constant column has zero variance; dividing would raise
        # ZeroDivisionError and abort the whole report.
        return None
    return moment(4, vals) / denominator - 3

def cv(vals):
    """Return the coefficient of variation of VALS as a percentage.

    The result is 100 * stddev / mean; a zero mean yields 0 instead of a
    division by zero.
    """
    average = mean(vals)
    if not average:
        return 0
    return 100 * stddev(vals) / average

# return the P-percentile
def percentile(p, vals):
    """Return the P-quantile of VALS, with P given as a fraction (0.5 = median).

    The result is the element at index int(len(vals) * p) of the sorted
    values; no interpolation is performed.  The index is clamped to the
    valid range, so p = 1.0 returns the maximum and a negative p returns
    the minimum.  Returns None for empty input.
    """
    ordered = sorted(vals)
    if not ordered:
        return None
    last = len(ordered) - 1
    # int(len * p) equals len for p = 1.0, which is one past the end.
    pos = min(max(int(len(ordered) * p), 0), last)
    return ordered[pos]

# return the LEVEL% confidence interval
def conf_interval(level, vals):
    """Return z * stddev(VALS) / sqrt(n) for the LEVEL% confidence level.

    LEVEL must be one of 90, 95, 99 or 99.9; any other level, or empty
    VALS, yields None.
    """
    z_by_level = {90: 1.645, 95: 1.960, 99: 2.576, 99.9: 3.29}
    z = z_by_level.get(level)
    if not z:
        return None
    n = count(vals)
    if n == 0:
        return None
    return z * stddev(vals) / math.sqrt(n)

def covariance(xvals, yvals):
    """Return the covariance of XVALS and YVALS.

    Values are paired positionally with zip(); the products of the
    deviations from each mean are averaged.
    """
    x_center = mean(xvals)
    y_center = mean(yvals)
    products = []
    for x, y in zip(xvals, yvals):
        products.append((x - x_center) * (y - y_center))
    return mean(products)

def correlation(x, y):
    """Return the correlation coefficient of X and Y.

    Computed as covariance(x, y) divided by the standard deviation of each
    input.  Returns None when either input has zero variance, where the
    coefficient is undefined; format_number() prints None as 'NaN'.
    """
    x_dev = math.sqrt(variance(x))
    y_dev = math.sqrt(variance(y))
    if x_dev == 0 or y_dev == 0:
        # A constant column would otherwise raise ZeroDivisionError.
        return None
    return covariance(x, y) / x_dev / y_dev

def rank(vals):
    """Return the 1-based ascending rank of every element of VALS.

    The result is in the same order as VALS.  Tied values all receive the
    highest rank of their group, because later positions in the sorted
    list overwrite earlier ones in the lookup table.
    """
    position = {v: i for i, v in enumerate(sorted(vals), start=1)}
    return [position[v] for v in vals]

def rank_correlation(x, y):
    """Return correlation() applied to the rank() values of X and Y."""
    return correlation(rank(x), rank(y))

def kendall_tau(x, y):
    """Return Kendall's rank correlation coefficient (tau-a) of X and Y.

    Every unordered pair of observations (i < j) scores +1 when X and Y
    change in the same direction, -1 when they change in opposite
    directions and 0 when either of them is tied.  The total is divided by
    the number of pairs, n * (n - 1) / 2.

    Observations are paired positionally; unmatched trailing values of the
    longer input are ignored.  Returns None when fewer than two
    observations are available, where tau is undefined.
    """
    pairs = list(zip(x, y))
    n = len(pairs)
    if n < 2:
        return None
    score = 0
    for i in range(n - 1):
        xi, yi = pairs[i]
        for j in range(i + 1, n):
            xj, yj = pairs[j]
            # The direction of change between two values is the same as
            # between their ranks, so no explicit ranking is needed.
            x_dir = (xj > xi) - (xj < xi)
            y_dir = (yj > yi) - (yj < yi)
            score += x_dir * y_dir
    return score / (n * (n - 1) / 2)

def fairness_index(vals):
    """Return Jain's fairness index of VALS: (sum x)**2 / (n * sum x**2).

    Returns None when VALS is empty or consists only of zeros, where the
    denominator is zero and the index is undefined.
    """
    n = len(vals)
    square_sum = sum(v**2 for v in vals)
    if n == 0 or square_sum == 0:
        return None
    return sum(vals)**2 / (n * square_sum)

def _extract_tail(vals):
    """Return the trailing part of VALS together with its mean.

    The tail starts at index int(len(vals) * (1 - STABILITY_TEST_FRACTION))
    and runs to the end of VALS.
    """
    head_fraction = 1 - STABILITY_TEST_FRACTION
    start = int(len(vals) * head_fraction)
    tail = vals[start:]
    return tail, mean(tail)

def is_stable(vals):
    """Return True if the tail of VALS stays close to its own mean.

    The tail and its mean come from _extract_tail().  The series counts as
    stable when both the maximum and the minimum of the tail lie within
    STABILITY_THRESHOLD * |mean| of the mean.  Empty input is reported as
    not stable.
    """
    if len(vals) == 0:
        return False
    tail, avg = _extract_tail(vals)
    # Scale the tolerance by the magnitude of the mean so that a series
    # converging to a negative level is judged like a positive one.
    tolerance = STABILITY_THRESHOLD * abs(avg)
    return (max(tail) - avg) <= tolerance and (avg - min(tail)) <= tolerance

def overshoot(vals):
    """Return the ratio of the tail maximum of VALS to the tail mean.

    Returns None when VALS is not stable according to is_stable(), or when
    the converged level (the tail mean) is zero and the ratio is undefined.
    """
    if not is_stable(vals):
        return None
    tail, avg = _extract_tail(vals)
    if avg == 0:
        return None
    # NOTE(review): the peak is taken from the tail only, which is_stable()
    # already bounds near the mean -- confirm whether the maximum of the
    # whole series was intended here.
    return max(tail) / avg

def rise_time(times, vals):
    """Return the first time at which VALS comes close to its converged level.

    TIMES and VALS are paired positionally.  The converged level is the
    tail mean from _extract_tail(); the result is the first time whose
    value deviates from that level by at most RISE_TIME_THRESHOLD relative
    to the level's magnitude.

    Returns None when VALS is not stable according to is_stable(), when the
    converged level is zero, or when no sample is close enough.
    """
    if not is_stable(vals):
        return None
    _, level = _extract_tail(vals)
    if level == 0:
        # The relative deviation is undefined for a zero level.
        return None
    scale = abs(level)
    for t, v in zip(times, vals):
        if abs(v - level) / scale <= RISE_TIME_THRESHOLD:
            return t
    return None

def settling_time(times, vals):
    """Return the last time at which VALS is still away from its converged level.

    TIMES and VALS are paired positionally.  The converged level is the
    tail mean from _extract_tail(); the samples are scanned from the end
    and the result is the time of the first one found whose deviation from
    that level is at least SETTLING_TIME_THRESHOLD relative to the level's
    magnitude.

    Returns None when VALS is not stable according to is_stable(), when the
    converged level is zero, or when every sample is within the threshold.
    """
    if not is_stable(vals):
        return None
    _, level = _extract_tail(vals)
    if level == 0:
        # The relative deviation is undefined for a zero level.
        return None
    scale = abs(level)
    # zip() returns an iterator, which reversed() cannot walk backwards;
    # materialize the pairs first.
    for t, v in reversed(list(zip(times, vals))):
        if abs(v - level) / scale >= SETTLING_TIME_THRESHOLD:
            return t
    return None

#----------------------------------------------------------------
# Dispatch table mapping each statistic name accepted by -t to the function
# that computes it.  Names listed in IS_MULTI are called with two columns,
# all others with a single column.
STAT_TBL = {
    # basic aggregates
    'count': count,
    'sum': sum,
    'total': sum,  # alias of 'sum', kept for backward compatibility
    'min': min,
    'max': max,
    'mean': mean,
    'median': median,
    'mode': mode,
    # dispersion and shape
    'var': variance,
    'stddev': stddev,
    'skew': skewness,
    'kurt': kurtosis,
    'cv': cv,
    # quantiles; the fraction is bound here so each entry takes one column
    'quant10': lambda x: percentile(0.10, x),
    'quant25': lambda x: percentile(0.25, x),
    'quant50': lambda x: percentile(0.50, x),
    'quant75': lambda x: percentile(0.75, x),
    'quant90': lambda x: percentile(0.90, x),
    # confidence intervals; the level is bound here in the same way
    'conf90': lambda x: conf_interval(90, x),
    'conf95': lambda x: conf_interval(95, x),
    'conf99': lambda x: conf_interval(99, x),
    # two-column statistics (see IS_MULTI)
    'covar': covariance,
    'corr': correlation,
    'spearmn': rank_correlation,
    'kendall': kendall_tau,
    'findex': fairness_index,
    # convergence-related statistics
    'stable': is_stable,
    'overshoot': overshoot,
    'rise': rise_time,
    'settling': settling_time,
}

def format_number(n):
    """Return N rendered with the 'g' format, or 'NaN' when N is None."""
    return 'NaN' if n is None else format(n, 'g')

def main():
    """Read numeric columns and print the requested statistics.

    Options (parsed with getopts from perlcompat):
      -v  prefix every output line with the name of the statistic
      -a  compute every statistic in STAT_TBL
      -t  comma-separated list of statistics to compute

    Input lines are read with fileinput.input() and split on whitespace.
    Lines that do not start with a digit, '.', '+', '-' or 'NaN' (after
    optional leading whitespace) are skipped.  One tab-separated line is
    printed per statistic, holding one value per column; statistics listed
    in IS_MULTI are computed for each pair of adjacent columns instead.

    An unknown statistic name is reported on standard error and terminates
    the program with exit status 1.
    """
    opt = getopts('vat:') or usage()

    stat_types = ('count sum min max mean median mode var stddev cv skew '
                  'kurt quant10 quant25 quant50 quant75 quant90 conf90 '
                  'conf95 conf99 findex').split()
    if opt.t:
        stat_types = opt.t.split(',')

    if opt.a:
        stat_types = sorted(STAT_TBL.keys())

    # Load each column into its own list.
    numeric_line = re.compile(r'\s*([\d.+-]|NaN)', re.IGNORECASE)
    data = []
    for line in fileinput.input():
        line = line.rstrip()
        if not numeric_line.match(line):
            continue

        fields = line.split()
        # Grow the table on demand so that a line with more fields than any
        # earlier line cannot index past the end of DATA.
        while len(data) < len(fields):
            data.append([])

        for col, val in enumerate(fields):
            if val == 'NaN':
                # NOTE(review): a literal 'NaN' discards everything collected
                # so far for this column -- confirm this reset is intended
                # rather than skipping the single value.
                data[col] = []
            else:
                data[col].append(float(val))

    # Display statistics for each column.
    for atype in stat_types:
        func = STAT_TBL.get(atype, None)
        if not func:
            # exit() is injected by the site module for interactive use and
            # reports success; fail explicitly through sys.exit() instead.
            print(f"Unsupported statistic `{atype}'", file=sys.stderr)
            sys.exit(1)
        if atype in IS_MULTI:
            # Pair every column with its right-hand neighbour; this yields
            # nothing when fewer than two columns were read.
            stats = [func(left, right) for left, right in zip(data, data[1:])]
        else:
            stats = [func(vals) for vals in data]
        if not stats:
            continue
        if opt.v:
            print(f'{atype}\t', end='')
        print('\t'.join([format_number(v) for v in stats]))

if __name__ == "__main__":
    main()
