# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/25_cv_data.ipynb (unless otherwise specified).

# Public names exported by `from <this module> import *` (this file is autogenerated from a notebook, see header)
__all__ = ['PathStr', 'TiffImage']

# Cell
# Import packages and modules required for cv.data module
import os
import string
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageSequence
from PIL.TiffImagePlugin import TiffImageFile, TiffTags, ImageFileDirectory_v2

# Type alias for path arguments: either a `pathlib.Path` or a plain string path
PathStr = Union[Path, str]

# Cell
class TiffImage:
    """Load and handle TIFF images, such as viewing frames, extracting frames and saving them as single image files.

    TiffImage loads a TIFF image from a path (Path or str) and returns a TiffImage object that:
    - gives access to the number of frames 'n_frames', and to the tags 'tags' and the size in pixels 'size'
      of the current frame.
    - allows to return a specific frame as an object
    - allows to extract and save all or any frame as 'tif' or 'jpg' image file(s)
    - provides a __repr__ including information on each of the frames in the image
    - allows to show thumbnails of all frames in a grid.

    The underlying PIL image has a "current frame": `summary_tags`, `show`, `get_frame` and
    `extract_one_frame` move it to the requested frame, whereas `__repr__`, `show_all` and
    `extract_frames` visit several frames and restore the current frame afterwards.

    Call `close()` to release the underlying PIL image, or use the object as a context manager:
    `with TiffImage(path) as img: ...`.

    Abbreviations in the code are according to https://docs.fast.ai/dev/abbr.html

    """

    # Characters allowed in file names built from tag values (see `extract_frames`)
    VALID_CHARS = f"-_.() {string.ascii_letters}{string.digits}"

    def __init__(self, path: PathStr):
        """Load the TIFF image located at `path`.

        Raises:
            FileExistsError: if `path` does not exist (see `handle_pathstr_`).
            ValueError: if the file extension is not '.tif' or '.tiff' (case-insensitive).
        """
        self.path = self.handle_pathstr_(path)
        # Compare case-insensitively so that '.TIF' / '.TIFF' files are accepted too
        if self.path.suffix.lower() not in ('.tif', '.tiff'):
            raise ValueError(f"Image file should be .tif or .tiff, not '{self.path.suffix}'")
        self.tiff = TiffImageFile(self.path)

    def close(self):
        """Close the underlying PIL image and release the file it opened."""
        self.tiff.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Always release the image; exceptions raised in the `with` body propagate unchanged
        self.close()

    def __repr__(self):
        """Return a summary of all the frames in the image file.

        Visiting each frame moves PIL's current frame; it is restored before returning so that
        calling `repr()` does not change what `tags` and `size` report.
        """
        lines = [f"<unpackai.cv.data.TiffImage> TIFF image file with {self.n_frames} frames."]
        lines.append(f"  Loaded from {self.path}.")
        lines.append('  Frame Content Summary:')
        lines.append(f"    {'Frame':^7s}|{'Size':^15s}|{'Nbr Tags':^10s}")
        current = self.tiff.tell()
        try:
            for i, frame in enumerate(ImageSequence.Iterator(self.tiff)):
                lines.append(f"    {str(i):^7s}|{str(frame.size):^15}|{str(len(frame.tag_v2)):^10s}")
        finally:
            self.tiff.seek(current)
        lines.append("  To show image thumbnails, use method '.show(frame_nbr)' or '.show_all()'")
        return '\n'.join(lines)

    def summary_tags(self, frame_nbr: int = 0):
        """Print a table of all the tags in frame `frame_nbr`, which becomes the current frame.

        Tags known to PIL (`TiffTags.TAGS_V2`) are listed first with their name, then custom tags
        with an empty name column. Tag values are truncated to 70 characters.

        Raises:
            ValueError: if `frame_nbr` is not a valid frame index.
        """
        if self.is_valid_frame_(frame_nbr):
            self.tiff.seek(frame_nbr)
        tag_dir = self.tiff.tag_v2
        # Direct dict membership instead of rebuilding a list of keys for every tag
        tags_predefined = [tag for tag in tag_dir if tag in TiffTags.TAGS_V2]
        tags_custom = [tag for tag in tag_dir if tag not in TiffTags.TAGS_V2]

        str_lst = [f"|{'Tag Nbr':^9s}|{'Predefined Tag Name':^31s}|{'Tag Value':^70s} |"]
        str_lst.append(f"|{'=' * 112} |")
        str_lst.extend([f"|{str(tag):^9s}| {TiffTags.TAGS_V2[tag].name:<30s}| {str(tag_dir.get(tag))[:70]:<70s}|" for tag in tags_predefined])
        str_lst.extend([f"|{str(tag):^9s}| {' ':^30s}| {str(tag_dir.get(tag))[:70]:<70s}|" for tag in tags_custom])
        print('\n'.join(str_lst))

    def show(self, frame_nbr: int = 0):
        """Display frame `frame_nbr` as a thumbnail; this frame becomes the current frame.

        Uses the same grey colormap as `show_all` so that single-channel frames are rendered the
        same way by both methods (matplotlib ignores the colormap for RGB(A) data).

        Raises:
            ValueError: if `frame_nbr` is not a valid frame index.
        """
        if self.is_valid_frame_(frame_nbr):
            print(f"Showing frame {frame_nbr} out of {self.n_frames}:")
            self.tiff.seek(frame_nbr)
            plt.figure(figsize=(2, 2))
            plt.imshow(self.tiff, cmap='Greys_r')
            plt.axis('off')
            plt.show()

    def show_all(self, n_max: Optional[int] = None):
        """Display the first `n_max` frames (all frames by default) as a grid of thumbnails.

        Thumbnails are laid out 6 per row. The current frame is restored afterwards.
        """
        if n_max is None:
            n_max = self.n_frames
            print(f"Showing all {self.n_frames} frames:")
        else:
            # Clip to [0, n_frames] so that a negative value cannot produce a non-positive grid size
            n_max = max(0, min(n_max, self.n_frames))
            print(f"Showing {n_max} first frames of {self.n_frames}:")

        ncols = 6
        # Ceiling division: no empty trailing row when n_max is a multiple of ncols; at least one row
        nrows = max(1, -(-n_max // ncols))
        plt.figure(figsize=(14, 2 * nrows))
        current = self.tiff.tell()
        try:
            for i in range(n_max):
                self.tiff.seek(i)
                plt.subplot(nrows, ncols, i + 1)
                plt.imshow(self.tiff, cmap='Greys_r')
                plt.axis('off')
        finally:
            self.tiff.seek(current)
        plt.show()

    def get_frame(self, frame_nbr: int = 0) -> TiffImageFile:
        """Return the frame specified by 'frame_nbr' as an object.

        The frame is returned as a 'PIL.TiffImagePlugin.TiffImageFile' object, which inherits from
        'PIL.Image.Image' and can be handled as a normal PIL image.

        NOTE: the returned object is the shared underlying image (`self.tiff`) positioned on
        `frame_nbr`, not a copy: any later call that moves to another frame changes what it holds.
        Call `.copy()` on the result to get an independent `PIL.Image.Image`.

        Raises:
            ValueError: if `frame_nbr` is not a valid frame index.
        """
        if self.is_valid_frame_(frame_nbr):
            self.tiff.seek(frame_nbr)
            return self.tiff

    def extract_frames(self,
                       dest: Optional[PathStr] = None,
                       naming_method: str = 'counter',
                       tag: Optional[int] = None
                      ) -> List[Path]:
        """Extract each frame from the file and save them as individual TIFF image files.

        All frames in the TIFF file are saved, with their tags, as independent TIFF image files in
        `dest` (the current working directory when not provided), named
        '<original file stem>-<frame slug>.tif'. The slug depends on `naming_method`:
         - 'counter':    frame number, 4-digit padded with zeros
         - 'tag_value':  value stored in tag `tag` of the frame, typically when a class name is stored
                         in a tag for each frame. Only `VALID_CHARS` are kept and spaces become '_'.
                         If `tag` is not provided, a warning is issued and naming reverts to 'counter'.
                         Frames without tag `tag`, or whose value leaves no valid character, use the
                         counter value instead.
        When a slug was already used by an earlier frame (e.g. several frames share a class name),
        the frame number is appended ('<slug>-<counter>') so that no file is overwritten.

        The current frame is restored afterwards.

        Returns:
            The paths of the saved files, in frame order.

        Raises:
            ValueError: if `naming_method` is neither 'counter' nor 'tag_value'.
            FileExistsError: if `dest` does not exist (see `handle_pathstr_`).
        """
        # Validate up front, so an invalid value is reported even when it would not be used
        if naming_method not in ('counter', 'tag_value'):
            raise ValueError(f"'naming_method' should be 'counter' or 'tag_value', not '{naming_method}'")
        if naming_method == 'tag_value' and tag is None:
            warnings.warn("'naming_method' is 'tag_value' but no 'tag' was provided: reverting to 'counter' naming")
            naming_method = 'counter'

        dest = self.handle_pathstr_(dest)

        saved_paths = []
        used_slugs = set()
        current = self.tiff.tell()
        try:
            for count, frame in enumerate(ImageSequence.Iterator(self.tiff)):
                counter_slug = f"{count:04d}"
                frame_slug = counter_slug
                if naming_method == 'tag_value' and tag in frame.tag_v2:
                    # Fall back to the counter when the value has no valid file-name character
                    frame_slug = self.slugify_(frame.tag_v2.get(tag)) or counter_slug
                if frame_slug in used_slugs:
                    # Do not silently overwrite a file saved earlier in this call
                    frame_slug = f"{frame_slug}-{counter_slug}"
                used_slugs.add(frame_slug)

                tiffpath = dest / f"{self.path.stem}-{frame_slug}.tif"
                frame.save(tiffpath, tiffinfo=frame.ifd)
                saved_paths.append(tiffpath)
        finally:
            self.tiff.seek(current)
        return saved_paths

    def extract_one_frame(self,
                          frame_nbr: int = 0,
                          image_format: str = 'jpg',
                          dest: Optional[PathStr] = None,
                          fname: Optional[str] = None,
                         ) -> Path:
        """Extract the specified frame and save it as an image file of the specified format.

        The extracted frame becomes the current frame.

        frame_nbr: int      index of the frame to extract (0-based)
        image_format: str   'tif' to save as a TIFF image, along with the frame IFD (tags)
                            'jpg' to save as a JPEG image, losing the tag information
        dest: Path or str   destination directory where to save the image file
                            when not provided, the current working directory is selected by default
        fname: str          specific name for the image file, without extension
                            when not provided, the name will be the same as the original TIFF image file
                            name, with the frame number as suffix (4-digit padded with zeroes)

        Returns:
            The path of the saved file.

        Raises:
            ValueError: if `image_format` is not 'jpg' or 'tif', or `frame_nbr` is not a valid frame index.
            FileExistsError: if `dest` does not exist (see `handle_pathstr_`).

        NOTE(review): PIL may refuse to write some image modes (e.g. 16-bit grayscale) as JPEG;
        such frames may need a conversion before saving as 'jpg'.
        """
        if image_format not in ['jpg', 'tif']:
            raise ValueError(f"'image_format' must be 'jpg' or 'tif', not '{image_format}'")

        dest = self.handle_pathstr_(dest)

        if self.is_valid_frame_(frame_nbr):
            self.tiff.seek(frame_nbr)
            stem = f"{self.path.stem}-{frame_nbr:04d}" if fname is None else fname
            fpath = dest / f"{stem}.{image_format}"
            self.tiff.save(fpath, tiffinfo=self.tiff.ifd)
            return fpath

    @property
    def n_frames(self) -> int:
        """Number of frames in the TIFF image file"""
        return self.tiff.n_frames

    @property
    def tags(self) -> Dict[int, Any]:
        """Dictionary with all tag number - tag value pairs in the current frame"""
        return dict(self.tiff.tag_v2.items())

    @property
    def size(self):
        """Size of the current frame in pixels, as reported by PIL (width, height)"""
        return self.tiff.size

    def is_valid_frame_(self, frame_nbr: int) -> bool:
        """Internal utility method to validate that the requested frame exists in the current TIFF file.

        Frames are 0-based, so valid indices are 0 to n_frames - 1.
        Returns True when valid; raises instead of returning False.

        Raises:
            ValueError: if `frame_nbr` is negative or not lower than `n_frames`.
        """
        if not 0 <= frame_nbr < self.n_frames:
            raise ValueError(f"'frame_nbr' is {frame_nbr} but this TIFF file only has {self.n_frames} frames "
                             f"(valid frame numbers: 0 to {self.n_frames - 1}).")
        return True

    @classmethod
    def slugify_(cls, value: Any) -> str:
        """Internal utility method turning a tag value into a file-name-safe slug.

        Bytes are decoded as UTF-8 (undecodable bytes dropped) and other values converted with `str()`.
        Only `VALID_CHARS` are kept and spaces are replaced by '_'. May return an empty string.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8', errors='ignore')
        return ''.join(c for c in str(value) if c in cls.VALID_CHARS).replace(' ', '_')

    @staticmethod
    def handle_pathstr_(pathstr: Optional[PathStr]) -> Path:
        """Internal utility method to handle pathlib.Path or string as path and validate that the path exists.

        `None` maps to `Path('')`, i.e. the current working directory, without any existence check.

        Raises:
            FileExistsError: if the path does not exist. FileNotFoundError would be the more accurate
                type; FileExistsError is kept so that existing callers catching it keep working.
        """
        if pathstr is None:
            return Path('')
        path = Path(pathstr)
        if not path.exists():
            raise FileExistsError(f"Cannot find path {path}")
        return path