# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_utils.ipynb (unless otherwise specified).

__all__ = ['get_name', 'find_contours', 'download_file_from_google_drive', 'mkdir', 'put_text', 'images_to_video',
           'av_i2v', 'video_to_images', 'av_v2i', 'TimeLoger', 'identify', 'memoize']

# Cell
import os
from glob import glob
import cv2
import os.path as osp
from tqdm import tqdm
import mmcv
from fastcore.script import call_parse, Param
from .process import multi_thread
from loguru import logger

def get_name(path):
    """
        Get the file name of a path without directory and last extension
            Arguments:
                path: path to a file
            Returns:
                The base name with only its final extension removed, e.g.
                'a/b/foo.tar.gz' -> 'foo.tar'. A name that has no extension
                is returned unchanged.
    """
    # splitext only strips the last extension and leaves extension-less
    # names intact; splitting on '.' and dropping the last piece turned a
    # name such as 'video' into an empty string.
    stem, _ext = osp.splitext(osp.basename(path))
    return stem


def find_contours(thresh):
    """
        Get contour of a binary image
            Arguments:
                thresh: binary image
            Returns:
                Contours: a list of contour
                Hierarchy: first element of the hierarchy array returned by
                    cv2.findContours
                Both values are None when cv2.findContours fails or reports
                no hierarchy.

    """
    try:
        # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 2/4
        # return (contours, hierarchy); the last two items are the same in
        # every version.
        contours, hierarchy = cv2.findContours(
            thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    except Exception:
        # Best effort by design: invalid input yields (None, None).
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return None, None
    if hierarchy is None:
        # No hierarchy to index into (indexing None used to be caught by the
        # bare except and produced the same result).
        return None, None
    return contours, hierarchy[0]


@call_parse
def download_file_from_google_drive(id_or_link: Param("Link or file id"), destination: Param("Path to the save file")):
    """
        Download a file from Google Drive and save it to destination
            Arguments:
                id_or_link: a file id, or an https link that holds the id
                    either in the path segment after /d/ or in an id= query
                    parameter
                destination: path of the file to write
            Returns:
                Absolute path of the saved file
            Raises:
                ValueError: no file id can be found in the link
                requests.HTTPError: the server answered with an error status
    """
    import requests
    from urllib.parse import parse_qs, urlparse

    def extract_file_id(value):
        # Anything that does not look like an https link is taken as the id.
        if "https" not in value:
            return value
        parsed = urlparse(value)
        segments = parsed.path.split("/")
        if "d" in segments:
            pos = segments.index("d") + 1
            if pos < len(segments) and segments[pos]:
                return segments[pos]
        # Links of the form ...?id=<file id>
        query_ids = parse_qs(parsed.query).get("id")
        if query_ids:
            return query_ids[0]
        raise ValueError(f"Cannot find a file id in link: {value}")

    def get_confirm_token(response):
        # The confirmation token is the value of the first cookie whose name
        # starts with 'download_warning'.
        # NOTE(review): this cookie-based confirmation may not match the
        # current Google Drive flow for large files - verify.
        return next(
            (value for key, value in response.cookies.items()
             if key.startswith('download_warning')),
            None)

    file_id = extract_file_id(id_or_link)
    logger.info(f"Download from id: {file_id}")

    URL = "https://docs.google.com/uc?export=download"
    chunk_size = 32768

    with requests.Session() as session:
        response = session.get(URL, params={'id': file_id}, stream=True)
        token = get_confirm_token(response)

        if token:
            response.close()
            response = session.get(
                URL, params={'id': file_id, 'confirm': token}, stream=True)

        try:
            # Fail before touching the destination instead of saving an
            # error page as if it were the requested file.
            response.raise_for_status()
            with open(destination, "wb") as f:
                for chunk in response.iter_content(chunk_size):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        finally:
            response.close()

    logger.info(f"Done -> {destination}")
    return osp.abspath(destination)


def mkdir(path):
    """
        Create a directory together with any missing parent directories
            Arguments:
                path: directory to create; an already existing directory is
                    not an error
    """
    os.makedirs(name=path, exist_ok=True)


def put_text(image, pos, text, color=(255, 255, 255)):
    """
        Draw a text on an image with cv2.putText
            Arguments:
                image: image to draw on
                pos: position passed to cv2.putText
                text: string to draw
                color: text color, white by default
            Returns:
                The value returned by cv2.putText
    """
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    return cv2.putText(
        image, text, pos, font_face, font_scale, color, thickness)


def images_to_video(
        images,
        out_path=None,
        fps: int = 30,
        no_sort=False,
        max_num_frame=10e12,
        resize_rate=1,
        with_text=False,
        text_is_date=False,
        verbose=True,
):
    """
        Write a sequence of images to an .mp4 or .avi video
            Arguments:
                images: a directory holding *.jpg / *.png / *.jpeg files, or
                    a list whose items are image paths or loaded images
                out_path: output video path ending with .mp4 or .avi. When
                    None, images must be a string and the video is written
                    to images + '.mp4' (trailing slashes are dropped first)
                fps: frame rate of the written video
                no_sort: when False and the items are paths, they are sorted
                    by the number formed from the digits of their file name
                max_num_frame: upper bound on the number of frames written
                resize_rate: scale applied to the size of the first image to
                    get the output frame size
                with_text: draw the file name on every frame loaded from a
                    path
                text_is_date: together with with_text, parse a file name of
                    the form <a>_<b>.<ext> as the timestamp a.b and draw the
                    resulting date instead of the name
                verbose: log the output size and show a progress bar
            Raises:
                FileNotFoundError: images is a string that is not a directory
                ValueError: there is no image to write
                NotImplementedError: out_path is neither .mp4 nor .avi
    """

    if out_path is None:
        assert isinstance(
            images, str), "No out_path specify, you need to input a string to a directory"
        # Drop trailing separators so that "frames/" gives "frames.mp4"
        # rather than the hidden file "frames/.mp4".
        out_path = (images.rstrip('/' + os.sep) or images) + '.mp4'
    if isinstance(images, str):
        if not os.path.isdir(images):
            raise FileNotFoundError(f"Not a directory of images: {images}")
        from glob import glob
        images = glob(os.path.join(images, "*.jpg")) + \
            glob(os.path.join(images, "*.png")) + \
            glob(os.path.join(images, "*.jpeg"))
    if len(images) == 0:
        raise ValueError("No images to write to video")

    def sort_key(path):
        # Names with digits sort first by their number, the rest by name.
        # The leading tag keeps ints and strs from being compared with each
        # other, which raises TypeError on Python 3.
        name = os.path.basename(path)
        digits = ''.join(c for c in name if c.isdigit())
        if digits:
            return (0, int(digits))
        return (1, name)

    def load_frame(img_or_path):
        # Loaded images are passed through; paths are read, resized and
        # optionally labelled.
        if not isinstance(img_or_path, str):
            return img_or_path
        img = mmcv.imread(img_or_path)
        # Check before resizing, otherwise cv2.resize fails first and the
        # offending path never shows up in the error.
        if img is None:
            raise IOError(f"Cannot read image: {img_or_path}")
        img = cv2.resize(img, output_size)
        if with_text:
            name = os.path.basename(img_or_path)
            if text_is_date:
                from datetime import datetime
                parts = name.split('.')[0].split('_')
                timestamp = float('{}.{}'.format(*parts))
                name = str(datetime.fromtimestamp(timestamp))
            img = put_text(img, (20, 20), name)
        return img

    if not no_sort and isinstance(images[0], str):
        images = sorted(images, key=sort_key)

    max_num_frame = min(len(images), int(max_num_frame))

    h, w = mmcv.imread(images[0]).shape[:2]
    output_size = (int(w*resize_rate), int(h*resize_rate))
    if out_path.endswith('.mp4'):
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    elif out_path.endswith('.avi'):
        fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    else:
        raise NotImplementedError

    # Load the frames before opening the writer so that a failed load does
    # not leave an empty, unreleased video file behind.
    frames = multi_thread(load_frame, images[:max_num_frame], verbose=verbose)

    out = cv2.VideoWriter(out_path, fourcc, fps, output_size)
    try:
        if verbose:
            logger.info(f"Write video, output_size: {output_size}")
            pbar = mmcv.ProgressBar(len(frames))
        logger.info(f"out_path: {out_path}")
        for img in frames:
            # Frames loaded from paths already have the output size; only
            # images passed in directly may still need resizing.
            if (img.shape[1], img.shape[0]) != output_size:
                img = cv2.resize(img, output_size)
            out.write(img)
            if verbose:
                pbar.update()
    finally:
        out.release()

# Cell
@call_parse
def av_i2v(
            images: Param("Path to the images folder or list of images"),
            out_path: Param("Output video path", str)=None,
            fps: Param("Frame per second", int) = 30,
            no_sort: Param("Do not sort images by the number in their file name", bool) = False,
            max_num_frame: Param("Max num of frame", int) = 10e12,
            resize_rate: Param("Resize rate", float) = 1,
            with_text: Param("Draw the file name on each frame when writing video", bool) = False,
            text_is_date: Param("With with_text, draw the file name parsed as a timestamp", bool) = False,
            verbose:Param("Print...", bool)=True,
        ):
    """Command line entry point of images_to_video: write images to a video."""
    # Arguments are forwarded by name so the mapping to images_to_video stays
    # correct even if its parameter order changes.
    return images_to_video(
        images,
        out_path=out_path,
        fps=fps,
        no_sort=no_sort,
        max_num_frame=max_num_frame,
        resize_rate=resize_rate,
        with_text=with_text,
        text_is_date=text_is_date,
        verbose=verbose,
    )


# Cell
def video_to_images(input_video, output_dir=None, skip=1):
    """
        Extract video to image:
            inputs:
                input_video: path to video
                output_dir: directory for the extracted frames, default is
                    .cache/video_to_images/<video name>
                skip: step between extracted frame indices, 1 writes every
                    frame
            Frames are written as <index:05d>.jpg. A frame that cannot be
            written is logged and skipped.
    """

    if output_dir is None:
        vname = get_name(input_video).split('.')[0]
        output_dir = osp.join('.cache/video_to_images/', vname)
        logger.info(f'Set output_dir = {output_dir}')

    video = mmcv.video.VideoReader(input_video)
    # Size the progress bar by the number of frames actually extracted, so it
    # reaches 100% when skip > 1, and use len() rather than a private field.
    frame_indices = range(0, len(video), skip)
    pbar = mmcv.ProgressBar(len(frame_indices))
    for i in frame_indices:
        out_img_path = os.path.join(output_dir, f'{i:05d}' + '.jpg')
        try:
            mmcv.imwrite(video[i], out_img_path)
        except Exception as e:
            logger.warning(f"Cannot write image index {i}, exception: {e}")
        # Count failed frames as processed too, otherwise the bar never
        # completes.
        pbar.update()

@call_parse
def av_v2i(input_video:Param("Path to the input video", str), output_dir:Param("Directory for the extracted frames (default: .cache/video_to_images/<video name>)", str)=None, skip:Param("Step between extracted frame indices", int)=1):
    """Command line entry point of video_to_images: extract video frames to .jpg files."""
    return video_to_images(input_video, output_dir=output_dir, skip=skip)

# Cell
from mmcv import Timer
import pandas as pd
import numpy as np

class TimeLoger:
    """Collect named durations from an mmcv Timer and print a summary.

    time_dict maps each name to the list of durations recorded for it.
    """

    def __init__(self):
        self.time_dict = dict()
        self.timer = Timer()

    def start(self):
        """Start the underlying timer."""
        self.timer.start()

    def update(self, name):
        """Record the value of timer.since_last_check() under `name`."""
        elapsed = self.timer.since_last_check()
        self.time_dict.setdefault(name, []).append(elapsed)

    def __str__(self):
        """Return one header line with the total time and one line per name
        with its share of the total, its mean duration and its count."""
        total_time = np.sum(
            [np.sum(durations) for durations in self.time_dict.values()])
        header = f"------------------Time Loger Summary : Total {total_time:0.2f} ---------------------:\n"
        rows = []
        for label, durations in self.time_dict.items():
            share = np.sum(durations)*100/total_time
            mean_duration = np.mean(durations)
            count = len(durations)
            rows.append(
                f'\t\t{label}:  \t\t{share:0.2f}% ({mean_duration:0.4f}s) | Times: {count} \n')
        return header + ''.join(rows)

# Cell
import xxhash
import pickle

def identify(x):
    '''Return the xxh64 (seed 0) hex digest of the pickled input'''
    payload = pickle.dumps(x)
    hasher = xxhash.xxh64(payload, seed=0)
    return hasher.hexdigest()


def memoize(func):
    '''Cache result of function call on disk
    Support multiple positional and keyword arguments

    The cache file is .cache/<function name>_<digest>, where the digest is
    computed by identify() from the function source and the call arguments.
    An existing cache file is ignored (and rewritten) when the environment
    holds a variable named like the function or a variable BUST_CACHE.
    Any failure of the caching itself is logged and the plain function
    result is returned; the function is called at most once per call.
    '''
    import inspect
    import os
    import pickle
    from functools import wraps

    @wraps(func)
    def memoized_func(*args, **kwargs):
        cache_dir = '.cache'
        try:
            func_id = identify((inspect.getsource(func), args, kwargs))
            cache_path = os.path.join(cache_dir, func.__name__+'_'+func_id)
        except Exception as e:
            # Source unavailable or arguments not picklable: no caching.
            logger.warning(f'Exception: {e}, use default function call')
            return func(*args, **kwargs)

        use_cache = (os.path.exists(cache_path) and
                     func.__name__ not in os.environ and
                     'BUST_CACHE' not in os.environ)
        if use_cache:
            try:
                # NOTE: pickle.load executes code from the file; only use
                # with a cache directory that is trusted.
                with open(cache_path, 'rb') as cache_file:
                    return pickle.load(cache_file)
            except Exception as e:
                # Unreadable cache: recompute and overwrite it below.
                logger.warning(f'Exception: {e}, use default function call')

        # Called outside of any try block so that an exception raised by
        # func propagates instead of triggering a second call.
        result = func(*args, **kwargs)
        try:
            os.makedirs(cache_dir, exist_ok=True)
            with open(cache_path, 'wb') as cache_file:
                pickle.dump(result, cache_file)
        except Exception as e:
            logger.warning(f'Exception: {e}, result is not cached')
        return result
    return memoized_func