# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/20_Image_Labeler.ipynb (unless otherwise specified).

__all__ = ['ImageLabeler', 'SingleClassImageLabeler', 'MultiClassImageLabeler']

# Cell
import json
from pathlib import Path
from forgebox.files import file_detail
from forgebox.html import DOM
import pandas as pd
from typing import List, Dict
from ipywidgets import interact, interact_manual, Button, SelectMultiple, \
    Output, HBox, VBox
from PIL import Image as PILImage
from ipywidgets import Image as ImageWidget
import logging
from tqdm.notebook import tqdm

# Cell
class ImageLabeler:
    """
    Base class for notebook image labelers.

    Scans ``image_folder`` with ``file_detail``, keeps only image files in
    ``self.image_df`` and tracks labeling progress in ``self.progress``::

        {"meta": {"image_folder": str, "labels": [...], "identifier": str},
         "data": {<str of identifier value>: <label or None>, ...}}

    Subclasses provide ``create_label_btns(row)`` and ``done_page()``, and
    set ``self.gen`` (the generator returned by ``__call__``) before
    ``render_page`` is used.
    """

    def __init__(self,
                 image_folder: Path,
                 formats: List[str] = ["jpg", "jpeg", "png", "bmp"],
                 ):
        """
        image_folder: Path, a folder full of images
        formats: a list of allowed file extensions (matched case-insensitively)
        """
        self.image_folder = image_folder
        self.file_df = file_detail(image_folder)
        self.filter_image(formats)
        self.output = Output()

    def __repr__(self):
        return f"{self.__class__.__name__} on [{self.image_folder}({len(self.image_df)})], see labeler.image_df"

    def filter_image(
        self,
        formats: List[str] = ["jpg", "jpeg", "png", "bmp"]
    ) -> pd.DataFrame:
        """
        Filter the file dataframe to image only files
        assign image_df attribute to the object

        formats: allowed extensions, compared case-insensitively with the
            ``file_type`` column. The given list is NOT modified.
        Returns the filtered dataframe (also stored as ``self.image_df``).
        """
        # Build a new set rather than ``formats += ...``: extending in place
        # mutated the caller's list and the shared default list, which then
        # kept growing on every call.
        allowed = {str(fmt).lower() for fmt in formats}
        is_image = self.file_df.file_type.astype(str).str.lower().isin(allowed)
        self.image_df = self.file_df[is_image].reset_index(drop=True)
        return self.image_df

    # NOTE: an earlier ``__call__`` that raised NotImplementedError used to
    # sit here; it was always shadowed by the generator ``__call__`` below,
    # which the subclasses rely on via ``super().__call__``.

    @property
    def identifier(self):
        """Name of the image_df column used as the progress key."""
        return self.progress['meta']['identifier']

    def save_progress(
        self,
        location: Path = Path("."),
        filename="unpackai_imglbl.json"
    ):
        """
        Save the progress to location/filename
        default save to current directory ./unpackai_imglbl.json

        location: directory to write into (str or Path)
        filename: name of the JSON file
        """
        target = Path(location) / filename
        # Serialize before opening: opening with "w" truncates the file, so a
        # serialization error must not be able to wipe existing progress.
        payload = json.dumps(self.progress)
        with open(target, "w", encoding="utf-8") as f:
            f.write(payload)
        logging.info("Progress saved to %s", target)

    @classmethod
    def load_saved(cls, filepath="./unpackai_imglbl.json"):
        """
        Load saved labeler's progress

        Rebuilds the labeler on the saved ``image_folder`` and attaches the
        saved progress dict.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            progress = json.loads(f.read())
        obj = cls(progress['meta']['image_folder'])
        obj.progress = progress
        return obj

    def new_progress(self, labels: List[str], identifier: str = "path"):
        """
        Start fresh progress: every image, keyed by ``str`` of its
        ``identifier`` column value, is marked unlabeled (``None``).
        """
        self.progress = dict(
            meta=dict(
                # str() keeps progress JSON-serializable when a Path is given
                image_folder=str(self.image_folder),
                labels=labels,
                identifier=identifier,
            ),
            data={str(k): None for k in self.image_df[identifier]},
        )

    def __call__(self, labels: List[str] = ["pos", "neg"]):
        """
        Generator yielding the keys that are still unlabeled (``None``).
        Creates new progress with ``labels`` if none exists yet.
        """
        self.labels = labels
        if not hasattr(self, "progress"):
            self.new_progress(labels)

        for k, v in tqdm(self.progress['data'].items(), leave=False):
            if v is None:
                yield k

    def __getitem__(self, key):
        """
        render a page according to key

        key: value of the identifier column (compared as str)
        """
        row = self.get_row_data(key)
        key = str(row[self.identifier])
        self.output.clear_output()
        with self.output:
            # Manage the opened file itself; the resized copy is independent
            # of it. NOTE(review): `display` is not imported here and seems to
            # rely on IPython exposing it as a builtin — confirm.
            with PILImage.open(row[self.identifier]) as img:
                display(img.resize((512, 512)))
            label_btns = self.create_label_btns(row)

            # current labeled label
            current = self.progress['data'][key]
            if current is not None:
                DOM(f"Current Label:{current}", "h5")()

            # navigation buttons; Last/Next are None at the ends
            candidates = [
                self.create_show_last_btn(key),
                self.create_show_next_btn(key),
                self.create_save_btn(),
                self.create_save_to_csv(),
            ]
            nav_btns = [btn for btn in candidates if btn is not None]
            display(VBox([label_btns, HBox(nav_btns)]))

    def get_row_data(self, key):
        """
        Return the image_df row whose identifier value equals ``key`` as a dict.

        Raises KeyError when no row matches.
        """
        identifier = self.identifier
        # Boolean mask instead of DataFrame.query with an interpolated string:
        # keys containing quotes (e.g. "John's.jpg") broke the query expression.
        mask = self.image_df[identifier].astype(str) == str(key)
        records = self.image_df[mask].to_dict(orient='records')
        if not records:
            raise KeyError(key)
        return dict(records[0])

    def render_page(self):
        """
        Render a new page
        Renders the next unlabeled key from ``self.gen``; when exhausted,
        saves progress and shows the done page.
        """
        try:
            key = next(self.gen)
        except StopIteration:
            self.save_progress()
            self.done_page()
            return
        self[key]

    def create_show_last_btn(self, key):
        """Button showing the previous image, or None on the first image."""
        keys = list(self.progress["data"])
        idx = keys.index(str(key))
        if idx == 0:
            return None
        last_key = keys[idx - 1]

        def show_last_click(_btn=None):
            self[last_key]

        btn = Button(description="Last", icon="arrow-left")
        btn.on_click(show_last_click)
        return btn

    def create_show_next_btn(self, key):
        """Button showing the next image, or None on the last image."""
        keys = list(self.progress["data"])
        idx = keys.index(str(key))
        if idx >= len(keys) - 1:
            return None
        next_key = keys[idx + 1]

        def show_next_click(_btn=None):
            self[next_key]

        btn = Button(description="Next", icon="arrow-right")
        btn.on_click(show_next_click)
        return btn

    def create_save_btn(self):
        """Button that saves progress as JSON with the default location."""
        btn = Button(description="Save JSON", icon='save')
        # on_click passes the button; keep it out of save_progress's `location`
        btn.on_click(lambda _btn: self.save_progress())
        return btn

    def save_to_csv(self):
        """
        Show a manual-run form asking for a csv path, then write the labeled
        items (unlabeled ones are skipped) with columns ``path`` and ``label``.
        """
        DOM("Please name a filepath for csv file like ./progress.csv","div")()

        @interact_manual
        def save_csv(path = "./progress.csv"):
            labeled = [(k, v) for k, v in self.progress["data"].items()
                       if v is not None]
            # Check the labeled items: with nothing labeled, zip(*[]) made
            # the unpacking below raise ValueError.
            if not labeled:
                DOM("Nothing to save","div")()
                return
            keys, vals = zip(*labeled)
            pd.DataFrame({"path":keys, "label":vals}).to_csv(path, index=False)
            DOM(f"Progress saved to: '{path}'","div")()

    def create_save_to_csv(self):
        """Button that opens the csv export form inside ``self.output``."""
        btn = Button(description="CSV", icon='save')

        def on_csv_click(_btn=None):
            # Callback output is dropped on some frontends (e.g. JupyterLab)
            # unless captured by an Output widget.
            with self.output:
                self.save_to_csv()

        btn.on_click(on_csv_click)
        return btn

class SingleClassImageLabeler(ImageLabeler):
    """Labeler where each image gets exactly one label (one button per label)."""

    def __init__(self, image_folder: Path):
        """
        image_folder: Path, a folder full of images
        """
        super().__init__(image_folder)

    def __call__(self, labels: List[str] = ["pos", "neg"]):
        """Start labeling with ``labels`` and display the labeling output."""
        self.gen = super().__call__(labels)

        self.render_page()

        display(self.output)

    def create_label_btns(self, row):
        """
        One button per label in ``self.labels``; clicking stores that label
        for ``row`` and renders the next unlabeled page.
        """
        k = str(row[self.identifier])

        def make_callback(label):
            # Factory binds `label` per button. A closure over the loop
            # variable made every button record the *last* label.
            def callback(_btn=None):
                self.progress["data"][k] = label
                self.render_page()
            return callback

        btns = []
        for label in self.labels:
            btn = Button(description=label, icon="check-circle")
            btn.on_click(make_callback(label))
            btns.append(btn)

        return HBox(btns)

    def done_page(self):
        """Replace the page with an end-of-iteration message."""
        self.output.clear_output()
        with self.output:
            DOM("That's the end of the iteration", "h3")()


class MultiClassImageLabeler(ImageLabeler):
    """Labeler where each image may get several labels via a multi-select."""

    def __init__(self, image_folder: Path):
        """
        image_folder: Path, a folder full of images
        """
        super().__init__(image_folder)

    def __call__(self, labels: List[str] = ["pos", "neg"]):
        """Start labeling with ``labels`` and display the labeling output."""
        self.gen = super().__call__(labels)
        DOM("press Command(mac) or Ctrl(win/linux) to select multiple","h4")()
        self.render_page()
        display(self.output)

    def create_label_btns(self, row):
        """
        A multi-select of ``self.labels`` plus a confirm button; confirming
        stores the selected labels (as a list) for ``row`` and renders the
        next unlabeled page.
        """
        select = SelectMultiple(options=self.labels)
        btn = Button(description="Okay!", icon="check-circle")

        def callback(_btn=None):
            k = row[self.identifier]
            self.progress["data"][str(k)] = list(select.value)
            self.render_page()

        # on_click is the public Button API (the handler receives the button)
        btn.on_click(callback)

        return HBox([select, btn])

    def done_page(self):
        """Replace the page with an end-of-iteration message."""
        self.output.clear_output()
        with self.output:
            DOM("That's the end of the iteration", "h3")()