# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_shared.ipynb (unless otherwise specified).

__all__ = ['SharedCircArrayBuffer', 'SharedDataCube', 'save_shared_datacube', 'SharedOpenHSI']

# Cell

from fastcore.foundation import patch
from fastcore.meta import delegates
import numpy as np
import ctypes
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Iterable, Union, Callable, List, TypeVar, Generic, Tuple, Optional, Dict
from functools import reduce
from pathlib import Path
import xarray as xr

# Cell

from .data import *

from ctypes import c_int32, c_uint32, c_float
from multiprocessing import Process, Queue, Array

# Cell


class SharedCircArrayBuffer(CircArrayBuffer):
    """Circular FIFO buffer whose storage lives in a `multiprocessing.Array`.

    Each put/get moves one (n-1)-dimensional slice along `axis`. The numpy array in
    `self.data` is a view onto the shared ctypes storage in `self.shared_data`, so the
    underlying memory can be handed to another process.
    """

    def __init__(self, size:tuple = (100,100), axis:int = 0, c_dtype:type = c_int32, show_func:Callable[[np.ndarray],"plot"] = None):
        """Allocate shared storage of shape `size` with element type `c_dtype` and reset the read/write pointers.

        `c_dtype` must be a ctypes type (e.g. `c_int32`, `c_float`) because it is passed
        straight to `multiprocessing.Array`. Read and write positions both start at index 0
        along `axis`.
        """
        n_elements = reduce(lambda acc, dim: acc * dim, size)
        self.shared_data = Array(c_dtype, n_elements)
        # np.frombuffer does not copy, so writes to self.data land in the shared array
        flat_view = np.frombuffer(self.shared_data.get_obj(), dtype=c_dtype)
        self.data = flat_view.reshape(size)

        self.size, self.axis = size, axis
        # full slices on every axis except `axis`, which carries the current slot index
        self.write_pos = [0 if dim_idx == axis else slice(None, None, None)
                          for dim_idx in range(len(size))]
        self.read_pos = list(self.write_pos)
        self.slots_left = size[axis]
        self.show_func = show_func

@delegates()
class SharedDataCube(CameraProperties):
    """Facilitates the collection, viewing, and saving of hyperspectral datacubes using
    two `SharedCircArrayBuffer`s that swap when save is called."""

    def __init__(self, n_lines:int = 16, processing_lvl:int = -1, **kwargs):
        """Preallocate two shared-memory datacube buffers (plus timestamp buffers) of `n_lines` lines each.

        `processing_lvl` selects the transform pipeline via `set_processing_lvl`; the remaining
        `kwargs` are forwarded to `CameraProperties`.
        """
        self.n_lines = n_lines
        self.proc_lvl = processing_lvl
        super().__init__(**kwargs)
        self.set_processing_lvl(processing_lvl)
        # Insert the line axis (axis 1, filled one line per `put`) into the per-frame output shape.
        self.dc_shape = (self.dc_shape[0],self.n_lines,self.dc_shape[1])
        # `multiprocessing.Array` needs a ctypes type, so map the numpy output dtypes onto ctypes.
        self.dtype_out = c_int32 if self.dtype_out is np.int32 else self.dtype_out
        self.dtype_out = c_float if self.dtype_out is np.float32 else self.dtype_out

        # Only one set of buffers is written at a time; `save` passes the active set to a
        # child process and switches to the other one.
        self.timestamps_swaps = [DateTimeBuffer(n_lines), DateTimeBuffer(n_lines)]
        self.dc_swaps         = [SharedCircArrayBuffer(size=self.dc_shape, axis=1, c_dtype=self.dtype_out),
                                 SharedCircArrayBuffer(size=self.dc_shape, axis=1, c_dtype=self.dtype_out)]
        # Derive the element size from the dtype instead of assuming 4 bytes per element.
        n_bytes = len(self.dc_swaps) * np.dtype(self.dtype_out).itemsize * reduce(lambda x,y: x*y, self.dc_shape)
        print(f"Allocated {n_bytes/2**20:.02f} MB of RAM.")

        self.current_swap = 0
        self.timestamps   = self.timestamps_swaps[self.current_swap]
        self.dc           = self.dc_swaps[self.current_swap]

    def __repr__(self):
        return f"DataCube: shape = {self.dc_shape}, Processing level = {self.proc_lvl}\n"

    def put(self, x:np.ndarray):
        """Apply the composed transforms to the 2D array `x`, write the result into the active
        datacube buffer, and record a timestamp for this push."""
        self.timestamps.update()
        self.dc.put( self.pipeline(x) )

@patch
def save(self:SharedDataCube, save_dir:str, preconfig_meta_path:str=None, prefix:str="", suffix:str="") -> Process:
    """Save the active datacube to NetCDF (plus a PNG preview) from a separate `multiprocessing.Process`.

    Files are written to `save_dir/<YYYY_MM_DD>/`, which is created if needed. They are named
    `<prefix><YYYY_MM_DD-HH_MM_SS><suffix>`, with the date and time taken from `self.timestamps[0]`.
    `preconfig_meta_path` optionally names a JSON file whose contents become the dataset attributes.

    Returns the started `Process`. After starting it, the active buffers are swapped so that
    collection can continue while the child reads the previous buffer.

    NOTE: the child reads the shared buffer in place. After the *next* `save`, collection writes
    into this buffer again, so join the returned process before then if saves can overlap.
    """
    import json  # not imported at module level; only needed for the optional metadata file

    if preconfig_meta_path is not None:
        with open(preconfig_meta_path, encoding="utf-8") as json_file:
            attrs = json.load(json_file)
    else:
        attrs = {}

    self.directory = f"{save_dir}/{self.timestamps[0].strftime('%Y_%m_%d')}"
    Path(self.directory).mkdir(parents=True, exist_ok=True)

    if hasattr(self, "binned_wavelengths"):
        wavelengths = self.binned_wavelengths if self.proc_lvl not in (3,7,8) else self.λs
    else:
        # no wavelength calibration available: fall back to band indices
        wavelengths = np.arange(self.dc.data.shape[2])

    self.coords = dict(wavelength=(["wavelength"],wavelengths),
                       x=(["x"],np.arange(self.dc.data.shape[0])),
                       y=(["y"],np.arange(self.dc.data.shape[1])),
                       time=(["time"],self.timestamps.data.astype(np.datetime64)))
    if hasattr(self,"cam_temperatures"):
        self.coords["temperature"] = (["temperature"],self.cam_temperatures.data)

    fname = f"{self.directory}/{prefix}{self.timestamps[0].strftime('%Y_%m_%d-%H_%M_%S')}{suffix}"

    p = Process(target=save_shared_datacube, args=(fname,self.dc.shared_data,self.dtype_out,self.dc.size,self.coords,attrs,self.proc_lvl))
    p.start()
    print(f"Saving {fname} in another process.")

    # swap to the other buffer set so collection does not overwrite the data being saved
    self.current_swap = 0 if self.current_swap == 1 else 1
    self.timestamps   = self.timestamps_swaps[self.current_swap]
    self.dc           = self.dc_swaps[self.current_swap]
    if hasattr(self,"cam_temperatures"):
        self.cam_temperatures = self.cam_temps_swaps[self.current_swap]
    return p


def save_shared_datacube(fname:str,shared_array:Array,c_dtype:type,shape:Tuple,coords_dict:Dict,attrs_dict:Dict,proc_lvl:int):
    """Write the datacube held in `shared_array` to `fname`.nc, then render a preview to `fname`.png.

    Designed as a `multiprocessing.Process` target. `shared_array` is viewed (not copied) as
    `c_dtype` and reshaped to `shape`. Its last axis is moved to the front so that the stored
    variable has dims ("wavelength", "x", "y"). `coords_dict` and `attrs_dict` become the
    dataset coordinates and global attributes. `proc_lvl` selects the `units` attribute of
    the datacube variable.
    """
    cube = np.frombuffer(shared_array.get_obj(), dtype=c_dtype).reshape(shape)

    nc = xr.Dataset(data_vars={"datacube": (["wavelength","x","y"], np.moveaxis(cube, -1, 0))},
                    coords=coords_dict, attrs=attrs_dict)

    # Descriptive metadata for each coordinate (written into the NetCDF file).
    # NOTE(review): `time` reuses the along-track labels, as in the original; confirm this is intended.
    coord_meta = {
        "x":          {"long_name": "cross-track", "units": "pixels",
                       "description": "cross-track spatial coordinates"},
        "y":          {"long_name": "along-track", "units": "pixels",
                       "description": "along-track spatial coordinates"},
        "time":       {"long_name": "along-track",
                       "description": "along-track spatial coordinates"},
        "wavelength": {"long_name": "wavelength_nm", "units": "nanometers",
                       "description": "wavelength in nanometers."},
    }
    if "temperature" in coords_dict:
        coord_meta["temperature"] = {"long_name": "camera temperature", "units": "degrees Celsius",
                                     "description": "temperature of sensor at time of image capture"}
    for coord_name, meta in coord_meta.items():
        nc[coord_name].attrs.update(meta)

    # Units depend on how far the processing pipeline went.
    if proc_lvl in (4,5,7):
        cube_units = "uW/cm^2/sr/nm"
    elif proc_lvl in (6,8):
        cube_units = "percentage reflectance"
    else:
        cube_units = "digital number"
    nc.datacube.attrs.update({"long_name": "hyperspectral datacube",
                              "units": cube_units,
                              "description": "hyperspectral datacube"})

    nc_path = fname + ".nc"
    nc.to_netcdf(nc_path)

    # Reload the saved file through DataCube to reuse its plotting for the PNG preview.
    import holoviews as hv
    hv.extension("bokeh",logo=False)
    preview = DataCube()
    preview.load_nc(nc_path)
    hv.save(preview.show("matplotlib",robust=True), fname + ".png")



# Cell

@delegates()
class SharedOpenHSI(SharedDataCube):
    """Base Class for the OpenHSI Camera."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        super().set_processing_lvl(self.proc_lvl)
        # Temperature buffers only exist when the concrete camera provides `get_temp`;
        # they are swapped alongside the datacube buffers in `save`.
        if callable(getattr(self,"get_temp",None)):
            self.cam_temps_swaps  = [CircArrayBuffer(size=(self.n_lines,),dtype=np.float32),
                                     CircArrayBuffer(size=(self.n_lines,),dtype=np.float32)]
            self.cam_temperatures = self.cam_temps_swaps[self.current_swap]

    def __enter__(self):
        return self

    def __close__(self):
        # NOTE(review): `__close__` is not a Python protocol method; it only runs when called explicitly.
        self.stop_cam()

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop_cam()

    def collect(self):
        """Collect `n_lines` frames into the hyperspectral datacube, recording the camera
        temperature per frame when `get_temp` is available. The camera is stopped even if
        a capture fails."""
        record_temp = callable(getattr(self,"get_temp",None))
        self.start_cam()
        try:
            for _ in tqdm(range(self.n_lines)):
                self.put(self.get_img())
                if record_temp:
                    self.cam_temperatures.put( self.get_temp() )
        finally:
            self.stop_cam()

    def avgNimgs(self, n) -> np.ndarray:
        """Take `n` images and return their pixel-wise mean as a float64 array.

        Each frame is cast to int32 before averaging. A running int64 sum is kept instead of
        buffering all `n` frames, so memory use does not grow with `n`. The camera is stopped
        even if a capture fails. Raises `ValueError` for negative `n`; `n == 0` yields NaNs.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        total = np.zeros(tuple(self.settings['resolution']), np.int64)

        self.start_cam()
        try:
            for _ in range(n):
                # int32 cast matches the per-frame storage dtype used previously
                total += np.asarray(self.get_img()).astype(np.int32)
        finally:
            self.stop_cam()
        return total / n
