# AUTOGENERATED! DO NOT EDIT! File to edit: ../16_CutMixRICAP.ipynb.

# %% auto 0
__all__ = ['CutMixRICAP']

# %% ../16_CutMixRICAP.ipynb 2
from .holemakertechnique import *
from .holemakerpoint import *
from .holesfilling import *
import numpy as np
import random

# %% ../16_CutMixRICAP.ipynb 4
class CutMixRICAP(HolesFilling):
    "Cuts every image (and mask) of a batch into a 2x2 grid and refills each cell with a cell drawn at random, without replacement, from the whole batch; applied with probability `p`."
    def __init__(self,
                 holes_num = 1, # Accepted for backward compatibility; not used (the grid is always 2x2).
                 hole_maker: "HoleMakerTechnique" = None, # The strategy used to make the holes (defaults to `HoleMakerPoint()`).
                 p = 1.0): # The probability of applying this technique to a batch.
        hole_maker = hole_maker if hole_maker else HoleMakerPoint()
        super().__init__(hole_maker)
        self.p = p

    @staticmethod
    def _snapshot(piece):
        "Returns an independent copy of `piece` (`clone` for torch tensors, `copy` for numpy arrays)."
        # Slicing with the holes returned by `make_hole` presumably yields a view into the batch
        # (TODO confirm the hole type). `fill_hole` writes into the batch in place, so without a
        # copy, filling one hole would silently alter pieces that have not been placed yet.
        return piece.clone() if hasattr(piece, "clone") else piece.copy()

    def before_batch(self):
        "Applies CutMixRICAP: splits each image/mask into 4 equal cells and refills them with cells picked at random from the batch."
        if random.random() >= self.p:
            return
        image_pieces, mask_pieces, holes = [], [], []
        for image, mask in zip(self.x, self.y):
            shape = image.shape[1:]
            # Floor division: for odd sizes the last row/column is left untouched so that all
            # 4 cells share one size and are interchangeable across images.
            cell_w, cell_h = shape[1] // 2, shape[0] // 2
            if cell_w == 0 or cell_h == 0:
                return  # Too small to split into a 2x2 grid; nothing has been modified yet.
            self.hole_maker.hole_size = (cell_w, cell_h)
            for top in (0, cell_h):
                for left in (0, cell_w):
                    self.hole_maker.x = left
                    self.hole_maker.y = top
                    xhole, yhole = self.make_hole(mask)
                    image_pieces.append(self._snapshot(image[:, yhole, xhole]))
                    mask_pieces.append(self._snapshot(mask[yhole, xhole]))
                    holes.append([xhole, yhole])

        cells_per_image = 4  # matches the 2x2 grid built above
        for image, mask in zip(self.x, self.y):
            for _ in range(cells_per_image):
                xhole, yhole = holes.pop()
                # Pick any remaining piece from the whole batch (this is what mixes images).
                rand = random.randrange(len(image_pieces))
                sub_image, sub_mask = image_pieces.pop(rand), mask_pieces.pop(rand)
                self.fill_hole(image, mask, xhole, yhole, [sub_image, sub_mask])
