#! /usr/bin/env python3
"""
Example application to fit data obtained with a CAEN N6725 read out by this software.
In this concrete example, we fit x-ray data for Si(Li) detectors for the GAPS experiment, 
but this program can be easily adapted to another scenario. 
Specific to GAPS is the grouping into detectors with 8 channels each. In addition, for this
particular calibration campaign, the calibration data is taken with dedicated electronics, so
the values for the flight electronics get projected. See the asic_projection part for that.
"""


import matplotlib
matplotlib.use('Agg')

import uproot as up
import dashi as d
import numpy as np
import hepbasestack as hep
import pylab as p
import HErmes as he
import hepbasestack.layout as lo
import concurrent.futures as fut
import os
import os.path
import re
import time

import tqdm

from dactylos.analysis.waveform_analysis import WaveformAnalysis
from dactylos.analysis.noisemodel import noise_model, fit_noisemodel, extract_parameters_from_noisemodel   
from dactylos.analysis.spectrum_fits import fit_file
from dactylos.analysis.utils import get_stripname,\
                                    plot_waveform,\
                                    get_data_from_mit_file
from dactylos.analysis.asic_projection import enc2_asic, asic_projection_plot
# call imagemagick with subprocess
# FIXME: there are python bindings

import shlex
import subprocess

from dataclasses import dataclass
from copy import deepcopy as copy

@dataclass
class PeakingTimeRun:
    """
    Result of a single spectrum fit, that is one channel of one detector
    at one peaking time. A value of -1 marks a field which has not been set.
    """
    ptime    : int = -1    # peaking/shaping time (handled in ns throughout this script)
    channel  : int = -1    # digitizer channel
    detid    : int = -1    # detector id
    fwhm     : float = -1  # fwhm as obtained from the fit
    fwhm_err : float = -1  # uncertainty of the fitted fwhm
    # Name of the plot file as returned by fit_file. It gets joined with the
    # output directory later on, so this is a file name and not a figure.
    # A figure instance as default would be created once at class definition
    # time (that is, on import) and then be shared by all instances.
    figname  : str = ''


# Global plotting setup, done once at import time and before any plot is made.
# NOTE(review): what the two calls change in detail is defined in
# hepbasestack and dashi, not in this file.
hep.visual.set_style_present()
d.visual()

# logger used throughout this script
# presumably 20 is the INFO level of the logging module - TODO confirm
logger = hep.logger.get_logger(20)

# switch off the interactive mode of pylab
p.ioff()

########################################################################



if __name__ == '__main__':

    import sys
    import argparse
    from glob import glob

    parser = argparse.ArgumentParser(description='Analyze data taken by N6725 digitizer.\n\
                                     \t\tThis script can do several things:\n\
                                     \t\t1) Analyze the histograms as saved by the MC2 windows\n\
                                     \t\t   software in textfiles (one line per bin)\n\
                                     \t\t2) Analyze dactylos data - root files with the\n\
                                     \t\    tenergy from the digitizer internal shaping algorithm\n\
                                     \t\t3) Analyze the waveform data taken by dactylos - root files\n\
                                     \t\t   with waveform entries')
    parser.add_argument('--mc2',
                        dest='mc2',
                        default=False, action='store_true',\
                        help='Analyze data taken with mc2/windows.\
                              Currently, this relies on a certain file structure.')
    parser.add_argument('--trapezoid',
                        dest='trapezoid',
                        default=False, action='store_true',\
                        help='Use trapezoid for shaping instead of\
                              the Gaussian.'
                       )
    #parser.add_argument('infiles', metavar='infiles', type=str, nargs='+',
    #                help='input root files, taken with CraneLab')
    #parser.add_argument('--debug', dest='debug',
    #                    default=False, action='store_true',
    #                    help='Set the loglevel to 10, that is debug')
    parser.add_argument('-i','--indir', dest='indir',
                        default=".",type=str,
                        help='glob all rootfiles in this directory for input"')
    parser.add_argument('--mit-infile', dest='mit_infile',
                        default=".",type=str,
                        help='Use a single file for input. Typically used for MC2 measurement performed at MIT"')
    parser.add_argument('-o','--outdir', dest='outdir',
                        default="firxraydata-outdir",type=str,
                        help='save plots in a directory named "outdir"')
    parser.add_argument('-j','--jobs', dest='njobs',
                        default=12,type=int,
                        help='number of jobs to use')
    parser.add_argument('--gaussian-shaping-times', nargs='+',\
                        help='Specify a sequence of peking times for the Gaussian\
                              shaper. If not given, a reasonable default is used',\
                        dest='gaussian_shaping_times',\
                        default=[],required=False)
    parser.add_argument('--channels', nargs='+',\
                        help='The sequence of channels to analyze. Default to 0,1,2,3,4,5,6,7',\
                        dest='channels',\
                        default=[0,1,2,3,4,5,6,7],required=False)
    parser.add_argument('--shaper-order', dest='shaper_order',
                        default=4,type=int,
                        help='order parameter of the Gaussian shaper. This parameter only has an effect together with --waveform-analysis.')
    parser.add_argument('--detector-capacitance', dest='detector_capacitance',
                        default=70,type=float,
                        help='Capacitance of the detector. Defaults to pF.\
                              Default should not be changed without good reason.\
                              This is only used for the asic prediction! [GAPS]')
    parser.add_argument('--shaper-different-orders',
                        dest='shaper_different_orders',
                        default=False, action='store_true',\
                        help='Change the order for the shaper from 4 to 7.\
                              It is 4 for pt < 5ms and 7 otherwise.')
    parser.add_argument('--waveform-analysis', dest='waveform_analysis',
                        default=False, action='store_true',
                        help='Use the waveform information to generate the noisemodel plot.')
    parser.add_argument('--plot-waveforms', dest='plot_waveforms',
                        default=False, action='store_true',
                        help='plot the individual waveforms as png')
    parser.add_argument('--combine-images', dest='combine_images',
                        default=False, action='store_true',
                        help='Use imagemagick to combine the images to a single pannel. This might be problemantic if there are too many shapingtimes used.')
    parser.add_argument('--asic-projection', dest='asic_projection',
                        default=False, action='store_true',
                        help='Also create the asic projection plot for the GAPS experiment. One plot per detector')
    #parser.add_argument('--run-sequence', dest='run_sequence',
    #                    default=False, action='store_true',
    #                    help='Only get the digitizer energy from the series of given files and produce the resolution plot. This is meant to be used for one file per shaping time')

    args = parser.parse_args()
    method = 'UNKNOWN' # label of the shaper, gets set to trapezoid or gauss below
                       # FIXME: What happens for gauss of different orders?
                       # -> For now, we use the gauss4....

    # create the outdir, it is fine if it exists already
    try:
        os.makedirs(args.outdir, exist_ok=True)
    except OSError as e:
        logger.warning(f'Can not create directory {args.outdir}! {e}')

    # the fits are distributed over a process pool with
    # as many workers as jobs have been asked for
    if args.njobs > 1:
        executor = fut.ProcessPoolExecutor(max_workers=args.njobs)

    # "." is the default of --mit-infile and means
    # that no single file has been given
    mit_infile = args.mit_infile if args.mit_infile not in (None, '', '.') else None

    # Find the input files and set up the pattern which allows to
    # extract the metadata (detector id, shaping time, strip) from the
    # file names.
    if args.mc2:
        if mit_infile:
            # a single file, e.g. Sh0123_-37C.txt. The file itself
            # gets read further below, here we only need its name
            regex = r'Sh0(?P<detid>[0-9]*)_-37C.txt'
            pattern = re.compile(regex)
            infiles = [mit_infile]
        else:
            # one subfolder per peaking time with one textfile per strip,
            # e.g. pt_4_mus/Sh0123_A.txt
            infiles = []
            subfolders = glob(os.path.join(args.indir, '*'))
            for subfolder in subfolders:
                infiles += glob(os.path.join(subfolder, '*.txt'))

            regex = r'pt_(?P<stime>[0-9]*)_(mus|mu_s|ns)\/Sh(?P<detid>[0-9]*)_(?P<strip>[A-H]).txt'
            pattern = re.compile(regex)
        method = 'trapezoid' # mc2 uses a trapezoid
    else:
        # dactylos root files, e.g. 123_stime4000.root
        infiles = glob(os.path.join(args.indir, '*.root'))
        regex = r'(?P<detid>[0-9]*)_(s|p)time(?P<stime>[0-9]*).root'
        pattern = re.compile(regex)

    # maps the submitted fit jobs to their peaking times,
    # this stays empty if everything is done in a single process
    future_to_ptime = dict()
    
    # store the results, one PeakingTimeRun per fitted spectrum
    datasets = []

    # channels to analyze
    active_channels = [int(k) for k in args.channels]

    # mapping detector strips <-> digitizer channels
    ch_name_to_int = {'A' : 0, 'B' : 1, 'C' : 2, 'D' : 3,\
                      'E' : 4, 'F' : 5, 'G' : 6, 'H' :7}

    if args.plot_waveforms and not args.waveform_analysis:
        raise ValueError("Can not plot waveforms without the --waveform-analysis switch!") 
    if args.waveform_analysis:
        logger.info("Will do anlysis directly on the waveforms")
        method = 'gauss'
        if args.trapezoid:
            logger.warning("Will use trapezoid for shaping instead of Gaussian - experimental feature!")
            method = 'trapezoid'

        if not infiles:
            raise ValueError(f'No input files found in {args.indir}!')

        # Get the detector id from the file names. The peaking times
        # are not taken from the files here, since the shaping is
        # done by us for the whole peaking time sequence.
        detids = []
        for infile in infiles:
            match = pattern.search(infile)
            if match is None:
                logger.warning(f'Can not extract the detector id from the file name {infile}')
                continue
            detids.append(int(match.groupdict()['detid']))
        if not detids:
            raise ValueError('Can not extract the detector id from any of the input file names!')
        # FIXME: this is buggy in case
        # multiple detectors are mixed
        detid = detids[-1]
        if len(set(detids)) > 1:
            logger.warning(f'Found several detector ids {sorted(set(detids))}, will use {detid} for all the files!')

        analysis = WaveformAnalysis(njobs=args.njobs,\
                                    active_channels=active_channels,\
                                    use_simple_trapezoid_shaper=args.trapezoid,\
                                    shaper_order=args.shaper_order,
                                    adjust_shaper_order_dynamically=args.shaper_different_orders)
        analysis.set_infiles(infiles)

        # in case of the waveform analysis, the actual analysis 
        # is performed here, since we first needed to just
        # get the filenames 
        analysis.read_waveforms()
        nevents = analysis.get_eventcounts()
        recordlengths = analysis.get_recordlengths()
        logger.info(f'Seeing recordlenghts of {recordlengths}')
        logger.info(f'Seeing nevents of {nevents}')
        for k in analysis.active_channels:
            logger.info(f" Record length in ch {k} ; {len(analysis.channel_data[k][0])}")
        
        if args.gaussian_shaping_times:
            peaking_sequence = np.array([int(k) for k in args.gaussian_shaping_times])
        else:
            # default peaking sequence typically used by Mengjiao
            peaking_sequence = np.array([500,1000,2000,4000,5000,8000,12000,16000,20000,30000])    
            #peaking_sequence = np.array([500,759,1000,2000,3000,4000,5000, 7000,8000,10000,12000,14000,16000,18000,20000,25000,30000])    
         
        logger.info(f'Setting peaking sequence {peaking_sequence} nano seconds')
        logger.info(f'Will use a {method} shaper')
        analysis.set_peakingtime_sequence(peaking_sequence)

        for ch in analysis.active_channels:
            if args.plot_waveforms:
                analysis.plot_waveforms(ch, 10)
            # the shaped energies for this channel, indexed by peaking time
            channel_peakingtimes = analysis.analyze(ch)

            logger.info('Shaping completed!')
            logger.info('Submitting fitting jobs...')

            for ptime in channel_peakingtimes.keys():
                # no input file here, the energies are handed over directly
                kwargs = {'energy'     : channel_peakingtimes[ptime],\
                          'channel'    : copy(ch),\
                          'plot_dir'   : copy(args.outdir),\
                          'ptime'      : copy(ptime),\
                          'detid'      : copy(detid),\
                          'file_type'  : None
                }
            
                if args.njobs > 1:
                    # the results are gathered later on
                    future_to_ptime[executor.submit(fit_file, **kwargs)] = ptime 
                else:
                    # use dedicated names for the returned values, so that
                    # the loop variables do not get overwritten
                    fit_detid, fit_ch, params, params_err, figname = fit_file(**kwargs)
                    data = PeakingTimeRun(detid = fit_detid,\
                                          channel = fit_ch,\
                                          fwhm = params[1],\
                                          fwhm_err = params_err['fwhm0'],\
                                          ptime = ptime,\
                                          figname = figname)
                    datasets.append(data)
    

    # in case we do the analysis on the 
    # energies as calculated by the digitizer
    # we will have one root file per shaping time
    # this type of analysis is performed here
    # there is one branch for linux and 
    # another one for mc2/windows

    # "." is the default of --mit-infile and means
    # that no single file has been given
    mit_infile = args.mit_infile if args.mit_infile not in (None, '', '.') else None

    def _fit_or_submit(kwargs, ptime):
        """
        Fit a single spectrum. For njobs > 1 the fit gets submitted to the
        process pool and the result is gathered later on, otherwise the
        fit is performed right away and the result added to the datasets.

        Args:
            kwargs (dict) : keyword arguments for fit_file
            ptime (int)   : peaking time of this spectrum
        """
        if args.njobs > 1:
            future_to_ptime[executor.submit(fit_file, **kwargs)] = ptime
            return
        fit_detid, fit_ch, params, params_err, figname = fit_file(**kwargs)
        datasets.append(PeakingTimeRun(detid = fit_detid,\
                                       channel = fit_ch,\
                                       fwhm = params[1],\
                                       fwhm_err = params_err['fwhm0'],\
                                       ptime = ptime,\
                                       figname = figname))

    if args.waveform_analysis:
        # we do not need this step here and can
        # skip it, the fits have been prepared already
        pass

    elif args.mc2 and mit_infile:
        # a single file which holds the whole measurement,
        # so this has to be done only once
        match = pattern.search(mit_infile)
        if match is None:
            raise ValueError(f'Can not extract the detector id from the file name {mit_infile}')
        # NOTE(review): the detector id stays a string here, while it is
        # an int for the other input types - confirm if that is intended
        detid = match.groupdict()['detid']
        # presumably this returns the list of results already, so that
        # no fits need to be performed - verify against get_data_from_mit_file
        datasets = get_data_from_mit_file(mit_infile)

    else:
        for infile in infiles:
            if args.mc2:
                logger.info(f"Loading {infile}")
            else:
                logger.info(f"Loading {infile} .. ")

            match = pattern.search(infile)
            if match is None:
                # a single stray file should not stop the whole analysis
                logger.warning(f'Can not extract the metadata from the file name {infile}, skipping it!')
                continue
            metadata = match.groupdict()
            ptime = int(metadata['stime']) # shaping/peaking time
            detid = int(metadata['detid']) # detector id

            if args.mc2:
                # one textfile per strip and peaking time
                if ptime < 500:
                    ptime *= 1000 # convert to ns
                strip = metadata['strip'] 
                ch = ch_name_to_int[strip]
                kwargs = {'infilename' : copy(infile),\
                          'channel'    : copy(ch),\
                          'plot_dir'   : copy(args.outdir),\
                          'ptime'      : copy(ptime),\
                          'detid'      : copy(detid),\
                          'file_type'  : '.txt'
                }
                _fit_or_submit(kwargs, ptime)
            else:
                # last option will be analyzing data taken 
                # with dactylos (root files) one file 
                # per shaping time
                logger.info(f'.. found peaking time of {ptime} for detector id {detid}')
                # analyze per channel, each file has all
                # the 8 channels saved 
                for ch in active_channels:
                    # be paranoid with mutable arguments
                    file_type = '.root'
                    kwargs = {'infilename' : copy(infile),\
                              'channel'    : copy(ch),\
                              'plot_dir'   : copy(args.outdir),\
                              'ptime'      : copy(ptime),\
                              'detid'      : copy(detid),\
                              'file_type'  : copy(file_type)
                    }
                    _fit_or_submit(kwargs, ptime)


    # get the results in the parallel case
    if args.njobs > 1:
        # the number of jobs is known, so the progress bar can show the total
        for future in tqdm.tqdm(fut.as_completed(future_to_ptime),\
                                desc='Calculating fit results..',\
                                total=len(future_to_ptime)):
            ptime = future_to_ptime[future]
            try:
                detid, ch, params, params_err, figname = future.result()
            except Exception as e:
                # a single failed fit should not stop the whole analysis
                logger.warning(f'Can not fit file for ptime {ptime}, exception {e}')
                continue
            data = PeakingTimeRun(detid = detid,\
                                  channel = ch,\
                                  fwhm = params[1],\
                                  fwhm_err = params_err['fwhm0'],\
                                  ptime = ptime,\
                                  figname = figname)
            datasets.append(data)
        # all the jobs are done, so the worker processes can be released
        executor.shutdown()
        logger.info(f'In total, {len(datasets)} datasets were processed ')

    # prepare all the plots we want to combine in 
    # one for a single peaking time
    pktimes_savenames = {dsptime.ptime : [] for dsptime in datasets}
    # the same for the plots of a single channel. There are always entries
    # for the 8 channels of a detector, plus whatever else has been asked for
    ch_savenames = {k : [] for k in sorted(set(range(8)) | set(active_channels))}

    # prepare a plot for the asic projections 
    # for this detector
    if args.asic_projection:
        asic_pr_plot = p.figure(dpi=120, figsize=lo.FIGSIZE_A4_LANDSCAPE_HALF_HEIGHT)
        asic_max = 0

    for ch in active_channels:
        stripname = get_stripname(ch)
        ch_datasets = [data for data in datasets if data.channel == ch]
        logger.info(f'{len(ch_datasets)} datasets available for channel {ch}')

        # sort them from smallest to largest peaking time. The datasets
        # themselves have to follow the same order, otherwise the plots
        # would get assigned to the wrong peaking times below
        idx = np.argsort(np.array([data.ptime for data in ch_datasets]))
        ch_datasets = [ch_datasets[i] for i in idx]
        xs = np.array([data.ptime for data in ch_datasets])
        ys = np.array([data.fwhm for data in ch_datasets])
        ys_err = np.array([data.fwhm_err for data in ch_datasets])

        try:
            __,  thisnoisemodel = fit_noisemodel(xs, ys, ys_err, ch, detid,\
                                                 use_minuit=True, plotdir=args.outdir,\
                                                 method=method)
        except Exception as e:
            logger.warning(f"Can not fit noisemodel for {detid} ch {ch}, exception {e}")
            continue

        if args.asic_projection:
            try:
                # FIXME: units!
                thisasic = enc2_asic(thisnoisemodel.xs,\
                                     thisnoisemodel.I_L*(1e-9),\
                                     thisnoisemodel.A_f*(1e-13),
                                     C=args.detector_capacitance*(1e-12))
                # limit the asic projection plot to (0.49, 2)
                xrange_mask = np.logical_and(thisnoisemodel.xs > 0.49,thisnoisemodel.xs < 2) 
                # keep track of the largest value of all channels, since
                # it sets the range of the y-axis later on
                if max(thisasic[xrange_mask]) > asic_max:
                    asic_max = max(thisasic[xrange_mask])

                asic_pr_plot  = asic_projection_plot(thisnoisemodel.xs, thisasic,
                                                     thisnoisemodel.stripname,
                                                     thisnoisemodel.detid,
                                                     fig=asic_pr_plot)
            
            except Exception as e:
                logger.warning(f"Can not plot asic projection for {detid} ch {ch}, exception {e}")

        # Remember the names of the plots of the individual fits, so
        # that they can be combined later on. The names are relative to
        # the output directory.
        for data in tqdm.tqdm(ch_datasets, desc=f'Saving data for {stripname}', total=len(ch_datasets)):
            savename = os.path.join(args.outdir, data.figname)
            ch_savenames[ch].append(savename)
            pktimes_savenames[data.ptime].append(savename)

    
    # save the asic projection plot
    if args.asic_projection:
        ax = asic_pr_plot.gca()
        # a thin line for xray/tracking 
        ax.hlines(4,0,2, color="gray", alpha=0.7, lw=0.5)
        logger.info(f'Found max asic {asic_max}')
        if asic_max > 0:
            # leave 10% of headroom above the largest value
            ax.set_ylim(top=asic_max + (0.1*asic_max))
        else:
            # none of the channels yielded a projection, so there
            # is nothing we could adjust the axis range to
            logger.warning('No asic projection available, will not adjust the y-axis range!')
        ax.legend(ncol=4,fontsize='small')
        asicplotname = f'asic-pr-det{detid}.png'
        asicplotname = os.path.join(args.outdir, asicplotname)
        asic_pr_plot.savefig(asicplotname)

        
    # let's write a textfile as well, one line per fitted spectrum
    # FIXME - this fails in case we have more than one
    # detector, in that case however the whole script will blow up
    # NOTE(review): the lines start with the detector id, which is not
    # mentioned in the header line
    summaryname = os.path.join(args.outdir, f'summary-{detid}.dat')
    # the context manager closes the file even if writing fails
    with open(summaryname, 'w') as summaryfile:
        summaryfile.write(f'# detector {detid}\n')
        summaryfile.write('# strip - peaking time  - fwhm - fwhm_err\n')
        for ds in datasets:
            summaryfile.write(f'{detid}\t{get_stripname(ds.channel)}\t{ds.ptime:4.2f}\t{ds.fwhm:4.2f}\t{ds.fwhm_err:4.2f}\n')




    # also combine all the figures for a specific channel or a specific
    # peaking time in one panel with imagemagick
    # this is a nice feature, however sometimes it can consume
    # a lot of memory
    if args.combine_images:

        def _run_imagemagick(command):
            """
            Execute an imagemagick command and wait until it is done.

            Args:
                command (list) : the program and its arguments. This is
                                 a list, so that file names with spaces
                                 or special characters are not a problem

            Returns:
                bool : True if the command has been executed successfully
            """
            logger.info(f'Executing {" ".join(command)}')
            try:
                result = subprocess.run(command)
            except OSError as e:
                logger.warning(f'Can not execute {command[0]} - is imagemagick installed? {e}')
                return False
            if result.returncode != 0:
                logger.warning(f'{command[0]} failed with return code {result.returncode}')
                return False
            return True

        def _combine_images(images, tile, mplotname):
            """
            Combine images to a single panel, which is saved as pdf, png
            and a jpg of reduced size.

            Args:
                images (list)   : file names of the images to combine
                tile (str)      : layout of the panel, e.g. 5x2
                mplotname (str) : file name of the panel without extension
            """
            created = {}
            for suffix in ('.pdf', '.png', '.jpg'):
                command = ['montage', '-geometry', '+4+4', '-tile', tile]
                command += list(images)
                command.append(mplotname + suffix)
                created[suffix] = _run_imagemagick(command)

            # shrink the jpg so it will go in google docs presentation
            if created['.jpg']:
                _run_imagemagick(['convert', f'{mplotname}.jpg',\
                                  '-resize', '50%', '-normalize',\
                                  f'{mplotname}.jpg'])

        # combine all the figures for a specific ch in 
        # one panel with imagemagick
        for ch in ch_savenames:
            if not ch_savenames[ch]:
                # nothing to combine for this channel
                continue
            mplotname = os.path.join(args.outdir,f'det{detid}-{get_stripname(ch)}')
            _combine_images(ch_savenames[ch], '5x2', mplotname)

        # .. and the same for each peaking time
        for pktime in pktimes_savenames:
            if not pktimes_savenames[pktime]:
                continue
            mplotname = os.path.join(args.outdir,f'det{detid}-pktime{pktime}')
            _combine_images(pktimes_savenames[pktime], '4x2', mplotname)


    print (f'Folder {args.indir} analyzed')
