#!/usr/bin/python3

import numpy as np
import seaborn as sns
import matplotlib.pylab as plt

# https://stackoverflow.com/questions/33282368/plotting-a-2d-heatmap-with-matplotlib
def heatmap(mx) -> None:
    '''Generate a heatmap from a matrix.

    mx is handed unchanged to seaborn.heatmap, which draws one cell per
    matrix element; cells are separated by lines of width 0.5.  The
    figure is then displayed with plt.show().
    '''
    # The Axes returned by seaborn is not needed here, so it is not kept.
    sns.heatmap(mx, linewidth=0.5)
    plt.show()

def heatmap2d(arr: np.ndarray):
    '''Show a 2D array as an image with a colour scale.

    The array is rendered with the 'viridis' colour map, a colour bar is
    attached to that image, and the figure is displayed with plt.show().
    '''
    image = plt.imshow(arr, cmap='viridis')
    # Attach the colour bar to the image that was just drawn.
    plt.colorbar(image)
    plt.show()
    
from ektools import index, parse, error

#########################################
if __name__ == '__main__':
    # Script entry point: read the file named on the command line, collect
    # its RAW0 datagrams per channel, and show one power heatmap per channel.

    import argparse  # NOTE(review): not used in the code visible here
    import sys

    # Fail with a readable message instead of an IndexError traceback when
    # the file argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} FILE')
    fn = sys.argv[1]

    # index() yields 4-tuples; the second field is the datagram type and the
    # fourth is what parse() consumes.  The third field is not used by this
    # script.
    idx = [(pos, kind, aux, msg)
           for pos, kind, aux, msg in index(fn) if kind == 'RAW0']
    if not idx:
        # NOTE(review): error() is presumably fatal - confirm in ektools.
        # If it returns, the loops below simply have nothing to do.
        error(f'File "{fn}" did not contain any RAW0 datagrams - aborted.')

    # Kept under its own name so the loop's tuple unpacking cannot
    # overwrite it (the progress display needs the real total).
    total = len(idx)

    # extract data from all RAW datagrams, grouped by channel
    ch = {}
    for i, (pos, _kind, _aux, msg) in enumerate(idx, start=1):
        # '\r' keeps the progress counter on a single terminal line.
        print(f'datagram {i}/{total} pos: {pos}', end='\r')
        obj = parse(msg)
        ch.setdefault(obj['channel'], []).append(obj)
    print()

    for c, objs in ch.items():
        # The first datagram of the channel fixes the number of rows;
        # the numpy assignment below raises if a later datagram's 'power'
        # has a different length.
        n_samples = len(objs[0]['power'])
        print(f'channel={c}, num objects={len(objs)} size={n_samples}')
        # One column per datagram, one row per power sample.
        mx = np.zeros(shape=(n_samples, len(objs)))
        for col, o in enumerate(objs):
            mx[:, col] = o['power']
        heatmap2d(mx)

    # TODO:
    # - match timestamps
    # - include frequency
    # - plot angles?
