from random import random
from typing import Union

import torch
from torch.utils.data import DataLoader


class MultiLoader:
    '''
    Iterable that builds every mini-batch out of several datasets at once.

    One shuffling ``DataLoader`` is created per dataset. Each call to
    ``next()`` takes one batch from every loader and concatenates them
    (in the order the datasets were given) into a single batch.
    '''

    def __init__(self, datasets: list, batch_size: Union[int, list]):
        '''
        Extension of the PyTorch dataloader. The main feature is the ability
        to create the returned minibatches by sampling from different datasets.
        A pass over the loader stops once the longest per-dataset loader has
        been consumed, i.e. when every element was returned at least once;
        shorter datasets are restarted as often as needed.
        e.g. If dataset A has 1000 elements, dataset B has 100 and we specify batch_size=10
             each mini batch will contain 5 elements from dataset A and 5 from dataset B.
             A total of 2000 elements will be returned, the elements from dataset A will be
             returned just once, the elements from dataset B will be returned multiple times.
             In this case, as long as the batch_size <= 200 there won't be repeated elements
             inside the mini batch.

        :param datasets: list of datasets used to create the minibatches.
        :param batch_size: if it's an int, batches of the specified size
            are returned. The returned batches are composed by sampling
            batch_size // len(datasets) elements from each dataset.
            If batch_size is a list it specifies, per dataset, how many
            elements to sample from it.
        '''
        self.datasets = []  # one DataLoader per input dataset
        self.no_datasets = len(datasets)
        self.no_steps = 0  # batches per pass = length of the longest loader
        self.actual_steps = 0  # batches already returned in the current pass
        self.iters = None  # per-loader iterators, created by __iter__

        if isinstance(batch_size, int):
            total = batch_size
            # Equal share for every dataset; any remainder is dropped.
            batch_size = [total // self.no_datasets for _ in range(self.no_datasets)]

        for i, ds in enumerate(datasets):
            dl = DataLoader(ds, batch_size=batch_size[i], shuffle=True, pin_memory=True)
            self.no_steps = max(self.no_steps, len(dl))
            self.datasets.append(dl)

    def __next__(self):
        '''
        Return the next mixed mini-batch.

        :return: tuple ``(batch_in, batch_out)`` holding the concatenation of
            the inputs and of the targets drawn from every dataset.
        :raises StopIteration: when ``len(self)`` batches have been returned
            in the current pass.
        '''
        if self.actual_steps >= self.no_steps:
            raise StopIteration

        if self.iters is None:
            # next() was called without a prior iter(): start the first pass.
            iter(self)

        batch_in = batch_out = None

        for i, loader in enumerate(self.datasets):
            try:
                x, y = next(self.iters[i])
            except StopIteration:
                # This dataset is shorter than the longest one: restart it so
                # it keeps contributing to every batch.
                self.iters[i] = iter(loader)
                x, y = next(self.iters[i])

            # Identity test: comparing a tensor with ``!=`` would go through
            # the tensor's rich comparison instead of a plain None check.
            batch_in = x if batch_in is None else torch.cat((batch_in, x))
            batch_out = y if batch_out is None else torch.cat((batch_out, y))

        self.actual_steps += 1
        return batch_in, batch_out

    def __iter__(self):
        '''
        Start a new pass: rewind the step counter and create a fresh
        iterator for each underlying loader.

        :return: the loader itself, which is its own iterator.
        '''
        # Without this reset every pass after the first one would be empty.
        self.actual_steps = 0
        self.iters = [iter(loader) for loader in self.datasets]
        return self

    def __len__(self):
        '''
        :return: number of mini-batches produced by one full pass.
        '''
        return self.no_steps


class Buffer:
    '''
    List-backed random sample of a dataset.

    The sample is stored in ``self.r`` and exposed through indexing and
    ``len()``.
    '''

    def __init__(self, ds, dim):
        '''
        Draw up to ``dim`` elements from ``ds`` with reservoir sampling: the
        first ``dim`` elements fill the buffer, then every later element
        replaces a random slot with probability ``dim / (i + 1)``.

        :param ds: indexable dataset supporting ``len()`` and ``ds[i]``.
        :param dim: number of elements to keep. If it exceeds ``len(ds)``
            the whole dataset is kept.
        '''
        # Imported here because at module level the name ``random`` is bound
        # to the function ``random.random``, which has no ``randint``.
        from random import randint

        size = len(ds)

        # Clamp so that a buffer larger than the dataset does not index
        # past its end.
        kept = [ds[i] for i in range(min(dim, size))]

        for i in range(dim, size):
            slot = randint(0, i)  # inclusive on both ends
            if slot < dim:
                kept[slot] = ds[i]

        self.r = kept

    def __getitem__(self, item):
        '''
        :return: the stored element(s) at index/slice ``item``.
        '''
        return self.r[item]

    def __len__(self):
        '''
        :return: number of elements currently stored.
        '''
        return len(self.r)

    def add(self, buffer, l):
        '''
        Append the elements of ``buffer`` to this buffer, repeated ``l`` times.

        :param buffer: iterable with the elements to add.
        :param l: how many copies of ``buffer`` to append; nothing is added
            when it is zero or negative.
        '''
        # Single concatenation instead of rebuilding the list once per copy.
        self.r = self.r + list(buffer) * l
