"""Authors: Saksham Sharda and Alessio Buccino."""
import os
from pathlib import Path
from warnings import warn
from typing import Optional
from copy import deepcopy

import numpy as np

from roiextractors import ImagingExtractor, SegmentationExtractor, MultiSegmentationExtractor
from pynwb import NWBFile, NWBHDF5IO
from pynwb.base import Images
from pynwb.file import Subject
from pynwb.image import GrayscaleImage
from pynwb.device import Device
from pynwb.ophys import (
    ImageSegmentation,
    Fluorescence,
    OpticalChannel,
    TwoPhotonSeries,
)

# from hdmf.commmon import VectorData
from hdmf.data_utils import DataChunkIterator
from hdmf.backends.hdf5.h5_utils import H5DataIO

from ..nwb_helpers import get_default_nwbfile_metadata, make_nwbfile_from_metadata, make_or_load_nwbfile
from ...utils import (
    FilePathType,
    OptionalFilePathType,
    dict_deep_update,
    calculate_regular_series_rate,
)


def get_default_ophys_metadata():
    """
    Return the default NWB file metadata extended with placeholder optical-physiology entries.

    The base metadata comes from ``get_default_nwbfile_metadata``; an ``"Ophys"`` section is then
    attached holding one default entry each for the device, the raw fluorescence response series,
    the image segmentation, the imaging plane (with a single optical channel) and the two-photon
    series. Physical quantities that are unknown (wavelengths) are set to ``np.nan``.

    Returns
    -------
    dict
    """
    metadata = get_default_nwbfile_metadata()

    microscope = dict(name="Microscope")
    raw_traces = dict(
        name="RoiResponseSeries",
        description="array of raw fluorescence traces",
    )
    plane_segmentation = dict(description="Segmented ROIs", name="PlaneSegmentation")
    optical_channel = dict(
        name="OpticalChannel",
        emission_lambda=np.nan,
        description="no description",
    )
    imaging_plane = dict(
        name="ImagingPlane",
        description="no description",
        excitation_lambda=np.nan,
        indicator="unknown",
        location="unknown",
        device=microscope["name"],  # the plane refers to its device by name
        optical_channel=[optical_channel],
    )
    two_photon_series = dict(
        name="TwoPhotonSeries",
        description="no description",
        comments="Generalized from RoiInterface",
        unit="n.a.",
    )

    metadata["Ophys"] = {
        "Device": [microscope],
        "Fluorescence": {"roi_response_series": [raw_traces]},
        "ImageSegmentation": {"plane_segmentations": [plane_segmentation]},
        "ImagingPlane": [imaging_plane],
        "TwoPhotonSeries": [two_photon_series],
    }
    return metadata


def get_nwb_imaging_metadata(imgextractor: ImagingExtractor):
    """
    Build NWB ophys metadata describing an ImagingExtractor.

    Starts from ``get_default_ophys_metadata()`` and fills in the optical channel names, the
    imaging rate of the first imaging plane and the dimension/rate/imaging plane of the first
    TwoPhotonSeries. The ``ImageSegmentation`` and ``Fluorescence`` sections are dropped because
    they are populated from a SegmentationExtractor instead.

    Parameters
    ----------
    imgextractor: ImagingExtractor

    Returns
    -------
    dict
    """
    metadata = get_default_ophys_metadata()
    ophys = metadata["Ophys"]
    plane = ophys["ImagingPlane"][0]
    series = ophys["TwoPhotonSeries"][0]
    optical_channels = plane["optical_channel"]

    names = imgextractor.get_channel_names()
    if names is None:
        # No names reported by the extractor: use a placeholder for every channel.
        names = ["generic_name"] * imgextractor.get_num_channels()

    # The default metadata already has one optical channel; rename it for the first channel and
    # append a fresh entry for every additional one.
    for position, name in enumerate(names):
        if position:
            optical_channels.append(
                dict(
                    name=name,
                    emission_lambda=np.nan,
                    description=f"{name} description",
                )
            )
        else:
            optical_channels[0]["name"] = name

    sampling_frequency = imgextractor.get_sampling_frequency()
    if sampling_frequency is None:
        rate = np.nan
    else:
        rate = float(sampling_frequency)

    plane["imaging_rate"] = rate
    series["dimension"] = list(imgextractor.get_image_size())
    series["rate"] = rate
    series["imaging_plane"] = plane["name"]

    # These sections are filled in by the segmentation metadata instead.
    del ophys["ImageSegmentation"]
    del ophys["Fluorescence"]
    return metadata


def add_devices(nwbfile: NWBFile, metadata: dict) -> NWBFile:
    """
    Add the optical physiology devices listed in ``metadata["Ophys"]["Device"]`` to ``nwbfile``.

    The user metadata (deep-copied, so the caller's dict is not modified) is deep-updated onto
    ``get_default_ophys_metadata()``. Each device entry may be either a keyword dict for
    ``pynwb.device.Device`` or an already-built Device. Devices whose name is already present on
    the file are skipped.

    Returns
    -------
    NWBFile
        The same ``nwbfile`` object, modified in place.
    """
    user_metadata = deepcopy(metadata)
    merged_metadata = dict_deep_update(get_default_ophys_metadata(), user_metadata, append_list=False)

    for entry in merged_metadata["Ophys"]["Device"]:
        if isinstance(entry, dict):
            device = Device(**entry)
        else:
            device = entry
        if device.name in nwbfile.devices:
            continue
        nwbfile.add_device(device)

    return nwbfile


def _add_imaging_plane(nwbfile: NWBFile, metadata: Optional[dict] = None) -> NWBFile:
    """
    Add the first imaging plane described in ``metadata["Ophys"]["ImagingPlane"]`` to ``nwbfile``.

    The devices from the metadata are added first (via ``add_devices``) so that the plane's
    ``device`` entry, which names a device, can be resolved to the Device object on the file.
    The optical channel dicts are converted to ``OpticalChannel`` objects. If an imaging plane with
    the same name already exists on the file it is left untouched (instead of raising on a
    duplicate name).

    Parameters
    ----------
    nwbfile: NWBFile
    metadata: dict, optional
        Metadata containing an "Ophys" section. It is deep-copied, so the caller's dict is not
        modified. Defaults to ``get_default_ophys_metadata()``.

    Returns
    -------
    NWBFile
        The same ``nwbfile`` object, modified in place.
    """
    # The original signature used `metadata=dict` (the class as a default value) by mistake.
    metadata_copy = get_default_ophys_metadata() if metadata is None else deepcopy(metadata)
    add_devices(nwbfile=nwbfile, metadata=metadata_copy)

    imaging_plane_kwargs = metadata_copy["Ophys"]["ImagingPlane"][0]
    if imaging_plane_kwargs.get("name") in nwbfile.imaging_planes:
        return nwbfile

    # Replace the device name with the actual Device object registered on the file.
    imaging_plane_kwargs["device"] = nwbfile.devices[imaging_plane_kwargs["device"]]
    imaging_plane_kwargs["optical_channel"] = [
        OpticalChannel(**channel_kwargs) for channel_kwargs in imaging_plane_kwargs["optical_channel"]
    ]

    nwbfile.create_imaging_plane(**imaging_plane_kwargs)

    return nwbfile


def add_two_photon_series(imaging, nwbfile, metadata, buffer_size=10, use_times=False):
    """
    Add the frames of an ImagingExtractor to ``nwbfile.acquisition`` as a TwoPhotonSeries.

    The metadata derived from the extractor (``get_nwb_imaging_metadata``) is deep-updated with a
    copy of ``metadata``; the first ``Ophys.TwoPhotonSeries`` entry provides the series keyword
    arguments and names the imaging plane to attach. If an acquisition object with the series name
    already exists, a warning is emitted and the file is returned unchanged.

    Frames are streamed one at a time through a ``DataChunkIterator`` wrapped in ``H5DataIO`` with
    compression. When ``calculate_regular_series_rate`` reports a regular rate, the series stores
    ``starting_time``/``rate``; otherwise it stores gzip-compressed ``timestamps``.

    Parameters
    ----------
    imaging: ImagingExtractor
    nwbfile: NWBFile
    metadata: dict
    buffer_size: int, default: 10
        ``buffer_size`` passed to the DataChunkIterator.
    use_times: bool, default: False
        Deprecated; only triggers a warning.

    Returns
    -------
    NWBFile
    """
    if use_times:
        warn("Keyword argument 'use_times' is deprecated and will be removed on or after August 1st, 2022.")

    user_metadata = deepcopy(metadata)
    merged_metadata = dict_deep_update(get_nwb_imaging_metadata(imaging), user_metadata, append_list=False)

    series_kwargs = merged_metadata["Ophys"]["TwoPhotonSeries"][0]
    series_name = series_kwargs["name"]
    if series_name in nwbfile.acquisition:
        warn(f"{series_name} already on nwbfile")
        return nwbfile

    nwbfile = _add_imaging_plane(nwbfile=nwbfile, metadata=merged_metadata)

    # Lazily yield one frame at a time so the full movie never has to sit in memory.
    frames = (imaging.get_frames(frame_idxs=[index]).squeeze().T for index in range(imaging.get_num_frames()))
    series_kwargs["data"] = H5DataIO(
        DataChunkIterator(frames, buffer_size=buffer_size),
        compression=True,
    )
    series_kwargs["dimension"] = imaging.get_image_size()

    timestamps = imaging.frame_to_time(np.arange(imaging.get_num_frames()))
    rate = calculate_regular_series_rate(series=timestamps)
    if not rate:
        # Irregular sampling: store explicit timestamps and clear any rate from the metadata.
        series_kwargs["timestamps"] = H5DataIO(timestamps, compression="gzip")
        series_kwargs["rate"] = None
    else:
        series_kwargs["starting_time"] = timestamps[0]
        series_kwargs["rate"] = rate

    # The metadata names the imaging plane; swap in the object created above.
    series_kwargs["imaging_plane"] = nwbfile.get_imaging_plane(name=series_kwargs["imaging_plane"])

    nwbfile.add_acquisition(TwoPhotonSeries(**series_kwargs))

    return nwbfile


def add_epochs(imaging, nwbfile):
    """
    Add (or update) epochs from an imaging extractor on ``nwbfile``.

    Each entry of ``imaging._epochs`` maps an epoch name to a dict with ``"start_frame"`` and
    ``"end_frame"``; frames are converted to times with ``imaging.frame_to_time``. If the file
    already has an epochs table containing an epoch tagged ``[name]``, its start/stop times are
    overwritten in place; otherwise a new epoch tagged with ``name`` is added. Extractors that do
    not define ``_epochs`` (or define it as None) add nothing.

    Parameters
    ----------
    imaging: ImagingExtractor
    nwbfile: NWBFile

    Returns
    -------
    NWBFile
        The same ``nwbfile`` object, modified in place.
    """
    # Not every extractor sets the private `_epochs` attribute; treat a missing one as "no epochs".
    epochs = getattr(imaging, "_epochs", None) or {}

    for name, epoch in epochs.items():
        start_time = imaging.frame_to_time(epoch["start_frame"])
        stop_time = imaging.frame_to_time(epoch["end_frame"])

        if nwbfile.epochs is not None:
            existing_tags = nwbfile.epochs["tags"][:]
            if [name] in existing_tags:
                index = existing_tags.index([name])
                nwbfile.epochs["start_time"].data[index] = start_time
                nwbfile.epochs["stop_time"].data[index] = stop_time
                continue

        nwbfile.add_epoch(start_time=start_time, stop_time=stop_time, tags=name)
    return nwbfile


def write_imaging(
    imaging: ImagingExtractor,
    nwbfile_path: OptionalFilePathType = None,
    nwbfile: Optional[NWBFile] = None,
    metadata: Optional[dict] = None,
    overwrite: bool = False,
    verbose: bool = True,
    buffer_size: int = 10,
    use_times=False,
    save_path: OptionalFilePathType = None,  # TODO: to be removed
):
    """
    Primary method for writing an ImagingExtractor object to an NWBFile.

    Parameters
    ----------
    imaging: ImagingExtractor
        The imaging extractor object to be written to nwb
    nwbfile_path: FilePathType
        Path for where to write or load (if overwrite=False) the NWBFile.
        If specified, the context will always write to this location.
    nwbfile: NWBFile, optional
        If passed, this function will fill the relevant fields within the NWBFile object.
        E.g., calling
            write_imaging(imaging=my_imaging_extractor, nwbfile=my_nwbfile)
        will result in the appropriate changes to the my_nwbfile object.
        If neither 'nwbfile_path' nor 'nwbfile' are specified, an NWBFile object will be automatically generated
        and returned by the function.
    metadata: dict, optional
        Metadata dictionary with information used to create the NWBFile when one does not exist or overwrite=True.
        If the extractor has an ``nwb_metadata`` attribute, that is used as the base which ``metadata`` updates.
    overwrite: bool, optional
        Whether or not to overwrite the NWBFile if one exists at the nwbfile_path.
        The default is False (append mode).
    verbose: bool, optional
        If 'nwbfile_path' is specified, informs user after a successful write operation.
        The default is True.
    buffer_size: int, default: 10
        Buffer size of the DataChunkIterator used to stream the imaging frames.
    use_times: bool, optional
        Deprecated; only triggers a warning.
    save_path: FilePathType, optional
        Deprecated alias of 'nwbfile_path'.

    Returns
    -------
    NWBFile
        The NWBFile object that was populated (and written, if 'nwbfile_path' was given).
    """
    assert save_path is None or nwbfile is None, "Either pass a save_path location, or nwbfile object, but not both!"
    if nwbfile is not None:
        assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile"

    if use_times:
        warn("Keyword argument 'use_times' is deprecated and will be removed on or after August 1st, 2022.")

    # TODO on or after August 1st, 2022, remove argument and deprecation warnings
    if save_path is not None:
        will_be_removed_str = "will be removed on or after August 1st, 2022. Please use 'nwbfile_path' instead."
        if nwbfile_path is not None:
            if save_path == nwbfile_path:
                warn(
                    "Passed both 'save_path' and 'nwbfile_path', but both are equivalent! "
                    f"'save_path' {will_be_removed_str}",
                    DeprecationWarning,
                )
            else:
                warn(
                    "Passed both 'save_path' and 'nwbfile_path' - using only the 'nwbfile_path'! "
                    f"'save_path' {will_be_removed_str}",
                    DeprecationWarning,
                )
        else:
            # The message previously named 'spikeinterface.write_recording' (copy-paste from the ecephys tools).
            warn(
                f"The keyword argument 'save_path' to 'write_imaging' {will_be_removed_str}",
                DeprecationWarning,
            )
            nwbfile_path = save_path

    if metadata is None:
        metadata = dict()
    if hasattr(imaging, "nwb_metadata"):
        # Copy so the extractor's own metadata dict is never modified by the deep update.
        metadata = dict_deep_update(deepcopy(imaging.nwb_metadata), metadata, append_list=False)

    with make_or_load_nwbfile(
        nwbfile_path=nwbfile_path, nwbfile=nwbfile, metadata=metadata, overwrite=overwrite, verbose=verbose
    ) as nwbfile_out:
        add_devices(nwbfile=nwbfile_out, metadata=metadata)
        add_two_photon_series(imaging=imaging, nwbfile=nwbfile_out, metadata=metadata, buffer_size=buffer_size)
        add_epochs(imaging=imaging, nwbfile=nwbfile_out)
    return nwbfile_out


def get_nwb_segmentation_metadata(sgmextractor):
    """
    Convert metadata from the segmentation into nwb specific metadata.

    Starts from ``get_default_ophys_metadata()`` and fills in:
    - the optical channel names of the first imaging plane (placeholder ``"generic_name"`` when the
      extractor reports no names, as in ``get_nwb_imaging_metadata``);
    - the rate of the default raw ``RoiResponseSeries`` when raw traces exist;
    - one extra response-series entry per non-raw trace that is not None and not 0-dimensional;
    - the imaging rate of the first imaging plane.
    The ``TwoPhotonSeries`` section is removed since it is populated from an ImagingExtractor.

    Parameters
    ----------
    sgmextractor: SegmentationExtractor

    Returns
    -------
    dict
    """
    metadata = get_default_ophys_metadata()

    # Optical Channel names (queried once rather than on every iteration):
    num_channels = sgmextractor.get_num_channels()
    channel_names = sgmextractor.get_channel_names()
    if channel_names is None:
        channel_names = ["generic_name"] * num_channels
    optical_channels = metadata["Ophys"]["ImagingPlane"][0]["optical_channel"]
    for index, channel_name in enumerate(list(channel_names)[:num_channels]):
        if index == 0:
            optical_channels[0]["name"] = channel_name
        else:
            optical_channels.append(
                dict(
                    name=channel_name,
                    emission_lambda=np.nan,
                    description=f"{channel_name} description",
                )
            )

    # Rate shared by the response series and the imaging plane; cast to float like the imaging metadata.
    sampling_frequency = sgmextractor.get_sampling_frequency()
    rate = np.nan if sampling_frequency is None else float(sampling_frequency)

    roi_response_series = metadata["Ophys"]["Fluorescence"]["roi_response_series"]
    for trace_name, trace_data in sgmextractor.get_traces_dict().items():
        if trace_data is None:
            continue
        if trace_name == "raw":
            # The default metadata already holds the raw RoiResponseSeries entry.
            roi_response_series[0].update(rate=rate)
        elif len(trace_data.shape) != 0:
            roi_response_series.append(
                dict(
                    name=trace_name.capitalize(),
                    description=f"description of {trace_name} traces",
                    rate=rate,
                )
            )

    metadata["Ophys"]["ImagingPlane"][0].update(imaging_rate=rate)
    # remove what imaging extractor will input:
    _ = metadata["Ophys"].pop("TwoPhotonSeries")
    return metadata


def _image_mask_iterator(segmentation_extractor):
    """Yield the transposed, squeezed image mask of each ROI of ``segmentation_extractor``, one at a time."""
    # The extractor is an explicit argument so each plane's lazily-consumed iterator stays bound to
    # its own extractor (a closure over a loop variable would read from the last one).
    for roi_id in segmentation_extractor.get_roi_ids():
        yield segmentation_extractor.get_roi_image_masks(roi_ids=[roi_id]).T.squeeze()


def _add_segmentation_plane(nwbfile, ophys, segext_obj, metadata, plane_index, buffer_size):
    """Add the device, imaging plane, plane segmentation, traces and images of one segmentation plane."""
    ophys_metadata = metadata["Ophys"]

    # Device:
    device_metadata = ophys_metadata["Device"][0]
    if device_metadata["name"] not in nwbfile.devices:
        nwbfile.create_device(**device_metadata)

    # ImageSegmentation:
    image_segmentation_name = "ImageSegmentation" if plane_index == 0 else f"ImageSegmentation_Plane{plane_index}"
    if image_segmentation_name in ophys.data_interfaces:
        image_segmentation = ophys.data_interfaces[image_segmentation_name]
    else:
        image_segmentation = ImageSegmentation(name=image_segmentation_name)
        ophys.add(image_segmentation)

    # ImagingPlane: a "name" in the metadata overrides the positional default when the plane is
    # created, so the same resolved name is used for the existence check and the lookup.
    imaging_plane_metadata = ophys_metadata["ImagingPlane"][0]
    default_imaging_plane_name = "ImagingPlane" if plane_index == 0 else f"ImagePlane_{plane_index}"
    imaging_plane_name = imaging_plane_metadata.get("name", default_imaging_plane_name)
    if imaging_plane_name in nwbfile.imaging_planes:
        imaging_plane = nwbfile.imaging_planes[imaging_plane_name]
    else:
        imaging_plane_kwargs = dict(imaging_plane_metadata)
        imaging_plane_kwargs.update(
            name=imaging_plane_name,
            optical_channel=[OpticalChannel(**channel) for channel in imaging_plane_metadata["optical_channel"]],
            # Attach this plane's own device (created above), not the first plane's.
            device=nwbfile.get_device(device_metadata["name"]),
        )
        if "imaging_rate" in imaging_plane_kwargs:
            imaging_plane_kwargs["imaging_rate"] = float(imaging_plane_kwargs["imaging_rate"])
        imaging_plane = nwbfile.create_imaging_plane(**imaging_plane_kwargs)

    # PlaneSegmentation:
    plane_segmentation_metadata = ophys_metadata["ImageSegmentation"]["plane_segmentations"][0]
    if plane_segmentation_metadata["name"] in image_segmentation.plane_segmentations:
        plane_segmentation = image_segmentation.get_plane_segmentation(plane_segmentation_metadata["name"])
    else:
        from hdmf.common import VectorData

        roi_ids = segext_obj.get_roi_ids()
        accepted_list = segext_obj.get_accepted_list()
        accepted_list = [] if accepted_list is None else accepted_list
        rejected_list = segext_obj.get_rejected_list()
        rejected_list = [] if rejected_list is None else rejected_list
        accepted_ids = [1 if roi_id in accepted_list else 0 for roi_id in roi_ids]
        rejected_ids = [1 if roi_id in rejected_list else 0 for roi_id in roi_ids]
        roi_locations = np.array(segext_obj.get_roi_locations()).T

        plane_segmentation_kwargs = dict(
            description="output from segmenting imaging plane",
            imaging_plane=imaging_plane,
        )
        plane_segmentation_kwargs.update(plane_segmentation_metadata)
        plane_segmentation_kwargs.update(
            columns=[
                VectorData(
                    data=H5DataIO(
                        DataChunkIterator(_image_mask_iterator(segext_obj), buffer_size=buffer_size),
                        compression=True,
                        compression_opts=9,
                    ),
                    name="image_mask",
                    description="image masks",
                ),
                VectorData(
                    data=roi_locations,
                    name="RoiCentroid",
                    description="x,y location of centroid of the roi in image_mask",
                ),
                VectorData(
                    data=accepted_ids,
                    name="Accepted",
                    description="1 if ROi was accepted or 0 if rejected as a cell during segmentation operation",
                ),
                VectorData(
                    data=rejected_ids,
                    name="Rejected",
                    description="1 if ROi was rejected or 0 if accepted as a cell during segmentation operation",
                ),
            ],
            id=roi_ids,
        )
        plane_segmentation = image_segmentation.create_plane_segmentation(**plane_segmentation_kwargs)

    # Fluorescence traces (the original looked up the misspelled "Flourescence" and so always tried
    # to add a second "Fluorescence" interface).
    if "Fluorescence" in ophys.data_interfaces:
        fluorescence = ophys.data_interfaces["Fluorescence"]
    else:
        fluorescence = Fluorescence()
        ophys.add(fluorescence)

    roi_table_region = plane_segmentation.create_roi_table_region(
        description=f"region for Imaging plane{plane_index}",
        region=list(range(segext_obj.get_num_rois())),
    )

    sampling_frequency = segext_obj.get_sampling_frequency()
    rate = np.nan if sampling_frequency is None else float(sampling_frequency)

    for signal, response_series in segext_obj.get_traces_dict().items():
        if response_series is None:
            continue
        data = np.asarray(response_series)
        # Skip traces that are entirely zero (placeholders for unavailable signals).
        if not np.any(data != 0):
            continue
        trace_name = "RoiResponseSeries" if signal == "raw" else signal.capitalize()
        if plane_index != 0:
            trace_name += f"_Plane{plane_index}"
        if trace_name not in fluorescence.roi_response_series:
            fluorescence.create_roi_response_series(
                name=trace_name,
                data=data.T,
                rois=roi_table_region,
                rate=rate,
                unit="n.a.",
            )

    if "TwoPhotonSeries" not in nwbfile.acquisition:
        warn("could not find TwoPhotonSeries, using ImagingExtractor to create an nwbfile")

    # Summary images (only the non-None ones are stored):
    images_dict = segext_obj.get_images_dict()
    if any(image is not None for image in images_dict.values()):
        images_name = "SegmentationImages" if plane_index == 0 else f"SegmentationImages_Plane{plane_index}"
        if images_name not in ophys.data_interfaces:
            images = Images(images_name)
            for image_name, image in images_dict.items():
                if image is not None:
                    images.add_image(GrayscaleImage(name=image_name, data=image.T))
            ophys.add(images)


def write_segmentation(
    segext_obj: SegmentationExtractor,
    save_path: FilePathType = None,
    plane_num=0,
    metadata: dict = None,
    overwrite: bool = True,
    buffer_size: int = 10,
    nwbfile=None,
):
    """
    Write a SegmentationExtractor (or each plane of a MultiSegmentationExtractor) to an NWBFile.

    Parameters
    ----------
    segext_obj: SegmentationExtractor
        The segmentation to write. For a MultiSegmentationExtractor, each of its ``segmentations``
        is written as a separate plane (plane 0 uses the un-suffixed names).
    save_path: FilePathType, optional
        Path of the ``.nwb`` file to write. Mutually exclusive with ``nwbfile``.
    plane_num: int, optional
        Currently unused.
    metadata: dict or list of dict, optional
        Metadata deep-updated onto ``get_nwb_segmentation_metadata`` for each segmentation. For a
        MultiSegmentationExtractor this must be a list with one dict per segmentation.
    overwrite: bool, default: True
        If False and ``save_path`` exists, the file is opened in "r+" mode and appended to;
        otherwise any existing file is removed and a new one is created.
    buffer_size: int, default: 10
        Buffer size of the DataChunkIterator streaming the ROI image masks.
    nwbfile: NWBFile, optional
        In-memory NWBFile to populate instead of writing to ``save_path``.
    """
    assert save_path is None or nwbfile is None, "Either pass a save_path location, or nwbfile object, but not both!"

    # parse metadata correctly:
    if isinstance(segext_obj, MultiSegmentationExtractor):
        segext_objs = segext_obj.segmentations
        if metadata is not None:
            assert isinstance(metadata, list), (
                "For MultiSegmentationExtractor enter 'metadata' as a list of " "SegmentationExtractor metadata"
            )
            assert len(metadata) == len(segext_objs), (
                "The 'metadata' argument should be a list with the same "
                "number of elements as the segmentations in the "
                "MultiSegmentationExtractor"
            )
    else:
        segext_objs = [segext_obj]
        if metadata is not None and not isinstance(metadata, list):
            metadata = [metadata]
    # updating base metadata with new:
    metadata_base_list = [
        dict_deep_update(
            get_nwb_segmentation_metadata(extractor),
            metadata[index] if metadata else {},
            append_list=False,
        )
        for index, extractor in enumerate(segext_objs)
    ]
    metadata_base_common = metadata_base_list[0]

    # build/retrieve nwbfile:
    io = None
    nwbfile_exist = False
    if nwbfile is not None:
        assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile"
    else:
        save_path = Path(save_path)
        assert save_path.suffix == ".nwb"
        nwbfile_exist = save_path.is_file() and not overwrite
        if not nwbfile_exist:
            if save_path.is_file():
                os.remove(save_path)
            if not save_path.parent.is_dir():
                save_path.parent.mkdir(parents=True)
        io = NWBHDF5IO(str(save_path), "r+" if nwbfile_exist else "w")

    # The io is always closed, even if populating the file fails part-way.
    try:
        if io is not None:
            nwbfile = io.read() if nwbfile_exist else make_nwbfile_from_metadata(metadata=metadata_base_common)

        # Subject:
        if metadata_base_common.get("Subject") and nwbfile.subject is None:
            nwbfile.subject = Subject(**metadata_base_common["Subject"])

        # Processing Module:
        if "ophys" in nwbfile.processing:
            ophys = nwbfile.get_processing_module("ophys")
        else:
            ophys = nwbfile.create_processing_module("ophys", "contains optical physiology processed data")

        for plane_index, (plane_extractor, plane_metadata) in enumerate(zip(segext_objs, metadata_base_list)):
            _add_segmentation_plane(
                nwbfile=nwbfile,
                ophys=ophys,
                segext_obj=plane_extractor,
                metadata=plane_metadata,
                plane_index=plane_index,
                buffer_size=buffer_size,
            )

        # Write once, after every plane has been added (previously done inside the plane loop,
        # which closed the io before the second plane).
        if io is not None:
            io.write(nwbfile)
    finally:
        if io is not None:
            io.close()
