import json
import logging
import os
import pickle
import shutil
import subprocess
from functools import reduce
from time import sleep

import requests

# Named logger used by the classes below; the root logger is configured at
# import time to emit records at DEBUG level and above.
logger = logging.getLogger("dnnevo")
logging.basicConfig(level=logging.DEBUG)


class DatasetDescriptor(object):

    def __init__(
            self,
            dataset_name,
            input_shape,
            output_shape,
            samples_count,
    ):
        self.__dataset_name = dataset_name
        self.__input_shape = input_shape
        self.__output_shape = output_shape
        self.__samples_count = samples_count
        self.__input_size = reduce(lambda l, r: l * r, input_shape)
        self.__output_size = reduce(lambda l, r: l * r, output_shape)

    @property
    def dataset_name(self):
        return self.__dataset_name

    @property
    def input_shape(self):
        return self.__input_shape

    @property
    def output_shape(self):
        return self.__output_shape

    @property
    def samples_count(self):
        return self.__samples_count

    @property
    def input_size(self):
        return self.__input_size

    @property
    def output_size(self):
        return self.__output_size


class DataStoreClient(object):
    """HTTP client for a sharded dataset store.

    Ids are resolved through a master service (address ``master:5001``), which
    reports the worker address holding each sample or batch; the payload is
    then fetched from that worker.

    :param dataset_name: dataset to query; lower-cased and used as id prefix.
    :param indices: stored as ``self.indices``; ``None`` or ``[]`` is replaced
        by ``list(range(10000))``.
    :param batch_size: number of ids sent per request to the master's
        ``/batch/get`` endpoint in :meth:`get_by_ids`.
    :param is_train: selects the ``_train`` (True) or ``_test`` id suffix.
    :param dir: directory where batch archives are downloaded and unpacked.
    :param is_local: if true, every service is reached on ``localhost`` at the
        port from its address; otherwise at its host name on port 5000.
    """

    # Seconds to wait between polls while a worker answers 202 for a batch.
    BATCH_POLL_SECONDS = 5

    def __init__(self, dataset_name, indices=None, batch_size=1, is_train=True, dir="./", is_local=True):
        if indices is None or indices == []:
            self.indices = list(range(10000))
        else:
            self.indices = indices
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.prefix = self.dataset_name + "_"
        self.suffix = "_" + ("train" if is_train else "test")
        self.dir = dir
        self.is_local = is_local

    def format_id(self, id):
        """Return the store key ``<prefix><id><suffix>`` for the string *id*.

        If *id* contains ``_`` (e.g. ``"mnist_5"``), only its second
        ``_``-separated field is kept before re-prefixing.
        """
        parts = id.split("_")
        if len(parts) > 1:
            id = parts[1]
        return "{}{}{}".format(self.prefix, id, self.suffix)

    def _chunks(self, l, n):
        """Yield successive n-sized chunks from l."""
        for i in range(0, len(l), n):
            yield l[i:i + n]

    def _url(self, url):
        """Turn a ``host:port`` service address into a base HTTP URL.

        Local mode maps every service to ``localhost`` on the address' port;
        otherwise the host name is used with the fixed port 5000.
        """
        parts = url.split(":")
        if self.is_local:
            return "http://localhost:" + parts[1]
        return "http://" + parts[0] + ":5000"

    def get_by_id(self, id):
        """Fetch and unpickle a single sample; return ``""`` if it cannot be loaded."""
        key = self.format_id(str(id))
        id_mapping = requests.get("{}/get/{}".format(self._url("master:5001"), key)).json()
        response = requests.get("{}/get/{}".format(self._url(id_mapping["worker"]), key))
        try:
            # NOTE(review): pickle.loads on bytes received over the network can
            # execute arbitrary code; only safe while the workers are trusted.
            return pickle.loads(response.content)
        except (IOError, EOFError, pickle.UnpicklingError) as ex:
            logger.error("Failed loading result: {}".format(ex))
            return ""

    def get_by_ids(self, ids):
        """Fetch the samples for *ids* in batches.

        :returns: ``(data, target)`` lists, one entry per sample that was
            downloaded successfully; failed batches are logged and skipped.
        """
        keys = [self.format_id(str(x)) for x in ids]
        paths = []
        for chunk in self._chunks(keys, self.batch_size):
            response = requests.post('{}/batch/get'.format(self._url("master:5001")),
                                     json.dumps({"ids": chunk}),
                                     headers={'content-type': 'Application/json'})
            for batch in response.json():
                path = self._download_batch(batch)
                if path is not None:
                    paths.append(path)
        data, target = [], []
        for path in paths:
            batch_data, batch_target = self._load_extracted(path)
            data.extend(batch_data)
            target.extend(batch_target)
        return data, target

    def _download_batch(self, batch):
        """Download one batch archive and unpack it into ``self.dir``.

        :param batch: master response entry with ``"id"`` and ``"addr"`` keys.
        :returns: path of the extracted directory, or None on failure.
        """
        batch_id = batch["id"]
        url = self._url(batch["addr"]) + "/batch/get/" + batch_id
        extract_dir = self.dir or "."
        archive = os.path.join(extract_dir, batch_id + ".tar.gz")
        try:
            response = requests.get(url)
            # 202 means the worker is still preparing the batch: wait, then poll.
            while response.status_code == 202:
                sleep(self.BATCH_POLL_SECONDS)
                response = requests.get(url)
            # requests' HTTPError subclasses IOError, so it is handled below.
            response.raise_for_status()
            with open(archive, "wb") as archive_file:
                archive_file.write(response.content)
            # -C unpacks next to the archive instead of into the process CWD.
            subprocess.run(["tar", "-zxvf", archive, "-C", extract_dir], check=True)
        except (IOError, subprocess.CalledProcessError) as ex:
            logger.error("Failed saving result: {}".format(ex))
            return None
        finally:
            if os.path.exists(archive):
                os.remove(archive)
        return os.path.join(extract_dir, batch_id)

    def _load_extracted(self, path):
        """Unpickle every ``(data, target)`` file in *path*, then delete *path*."""
        data, target = [], []
        try:
            # Sorted so sample order is deterministic (os.listdir order is arbitrary).
            for name in sorted(os.listdir(path)):
                with open(os.path.join(path, name), "rb") as sample_file:
                    # NOTE(review): unpickling downloaded files is unsafe if a
                    # worker is compromised.
                    sample, label = pickle.load(sample_file)
                data.append(sample)
                target.append(label)
        finally:
            # Best-effort cleanup, like the former `rm -r`, even if loading fails.
            shutil.rmtree(path, ignore_errors=True)
        return data, target

    def dataset_size(self):
        """Return the ``size`` field the master reports for this dataset."""
        # NOTE(review): self.suffix (train/test) is not part of this URL, so the
        # size is not split-specific — confirm against the master's API.
        response = requests.get('{}/size/{}'.format(self._url("master:5001"), self.dataset_name))
        return response.json()["size"]

    def get_descriptor(self):
        """Build a DatasetDescriptor from the master's ``/descriptor`` response."""
        response = requests.get('{}/descriptor/{}'.format(self._url("master:5001"), self.dataset_name))
        jsn = response.json()
        return DatasetDescriptor(
            self.dataset_name,
            tuple(int(x) for x in jsn["input_size"]),
            tuple(int(x) for x in jsn["output_size"]),
            int(jsn["size"]),
        )
