import copy
import logging
import io
from timeit import default_timer
from abc import ABC
from . import s3
from typing import List, Dict, Union, Optional, Tuple
from pathlib import Path
from .utils import download
from . import ssh
import subprocess
import json
from PIL import Image
import numpy as np
import base64
import hashlib


class Uri:
    """Thin wrapper around a URI string; hashes and compares by that string.

    Note that ``repr()`` returns the raw URI while ``str()`` returns the
    bracketed ``<Uri uri=...>`` form (the reverse of the usual convention);
    this is kept unchanged for existing callers.
    """

    def __init__(self, uri: str) -> None:
        self.uri = uri

    def __repr__(self):
        return self.uri

    def __hash__(self) -> int:
        return hash(self.uri)

    def __eq__(self, other: object) -> bool:
        # Any object exposing a ``uri`` attribute is compared by it (as before);
        # objects without one yield NotImplemented so Python falls back to an
        # identity comparison (False) instead of raising AttributeError.
        try:
            other_uri = other.uri
        except AttributeError:
            return NotImplemented
        return other_uri == self.uri

    def __str__(self):
        return f"<Uri uri={self.uri}>"


def encode(url: str):
    """Return ``url`` as URL-safe Base64 text (UTF-8 bytes, ``=`` padding kept)."""
    raw = bytes(url, "utf-8")
    return base64.urlsafe_b64encode(raw).decode("utf-8")


def encode_short(url: str):
    """Return the 40-character hex SHA-1 digest of ``url``'s UTF-8 bytes."""
    hasher = hashlib.sha1()
    hasher.update(bytes(url, "utf-8"))
    return hasher.hexdigest()


class Target(ABC):
    """Interface for something a task produces (``target()``) or needs (``depends()``).

    Subclasses are expected to override both methods; the base versions only
    raise ``NotImplementedError``.
    """

    def delete(self):
        """Remove the target."""
        raise NotImplementedError

    def exists(self):
        """Return whether the target is present."""
        raise NotImplementedError


class S3Target(s3.S3File, Target):
    """An S3 object used as a task target or dependency.

    ``upload``/``download``/``unlink`` come from ``s3.S3File`` (presumably
    ``exists()`` too — confirm). ``Target`` is mixed in so that the
    ``isinstance(x, Target)`` checks in this module treat S3 objects as
    targets instead of falling through to the task code path.
    """

    def delete(self):
        """Delete the S3 object via ``S3File.unlink``."""
        return super().unlink()

    @classmethod
    def from_uri(cls, uri: str):
        """Build a target from an ``s3://bucket/path`` URI."""
        bucket, path = s3.S3.split_uri(uri)
        # ``cls`` rather than a hard-coded S3Target so subclasses build their own type.
        return cls(bucket, path)


class LocalTarget(Target):
    """A target backed by a file on the local filesystem.

    The path is built from ``path`` joined with any extra ``args`` segments and
    made absolute at construction time.
    """

    def __init__(self, path: Union[str, Path], *args):
        self.path = Path(path, *args).absolute()

    def exists(self):
        return self.path.exists()

    def open(self, mode: str, **args):
        """Open the file; keyword arguments are forwarded to ``Path.open``."""
        return self.path.open(mode=mode, **args)

    def delete(self):
        """Remove the file; a missing file is not an error."""
        return self.path.unlink(missing_ok=True)

    def read_text(self):
        """Return the whole file as text."""
        with self.open("r") as reader:
            data = reader.read()

        return data

    def read_image(self, pil: bool = False):
        """Load the file as an image.

        With ``pil=True`` the ``PIL.Image.Image`` returned by ``Image.open`` is
        handed back still open, and the caller owns it. Otherwise the pixels are
        returned as a NumPy array and the file is closed.
        """
        if pil:
            return Image.open(self.path)
        # Close the image file once NumPy has read its pixels instead of leaving
        # the handle open until garbage collection.
        with Image.open(self.path) as image:
            return np.asarray(image)

    def read_json(self):
        """Parse the file as JSON and return the decoded object."""
        # Use a context manager so the file handle is closed deterministically.
        with self.open("r") as reader:
            return json.load(reader)


# What Task.target() returns: a single target, a list of targets, or a
# name -> target mapping.
OutputType = Union[
    Target,
    List[Target],
    Dict[str, Target],
]
# What Task.depends() returns: targets and/or tasks ("Task" is a forward
# reference), either single, in a list, or in a name -> item mapping.
Dependency = Union[
    Target,
    List[Target],
    "Task",
    List["Task"],
    Dict[str, Target],
    Dict[str, "Task"],
]


def to_list(o: Optional[OutputType]):
    """Normalize a single item, collection, or None into a list.

    ``None`` becomes ``[]``. A list is returned as-is (same object). A dict
    contributes its values, a tuple is copied into a new list, and anything
    else is wrapped in a one-element list.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if isinstance(o, dict):
        return list(o.values())
    if isinstance(o, tuple):
        return list(o)
    return [o]


def depedendencies_resolved(deps: Dependency) -> bool:
    """Return True when every dependency in ``deps`` is satisfied.

    A ``Task`` dependency is satisfied when ``done()`` is True. Anything else is
    treated as a target and is satisfied when ``exists()`` is True. No
    dependencies (None / empty) means resolved.
    """
    # Dispatch on Task rather than Target: target types that are not Target
    # subclasses (e.g. an S3Target built only on s3.S3File) were previously sent
    # to ``.done()``, which targets do not define.
    return all(
        dep.done() if isinstance(dep, Task) else dep.exists()
        for dep in to_list(deps)
    )


class Task(ABC):
    """Base class for a unit of work with dependencies and output targets.

    Subclasses override ``run()`` and usually ``target()`` and ``depends()``.
    ``repr()``, hashing and equality are all derived from the public
    (non-underscore) instance attributes via ``_get_args()``. Attribute
    insertion order therefore matters.
    """

    def depends(self) -> Dependency:
        """Return what this task needs before it can run (nothing by default)."""
        return []

    def run(self):
        """Run the task and write its output to the target(s)."""
        raise NotImplementedError(f"task {self.__class__.__name__} does not implement run() method")

    def target(self) -> OutputType:
        """Return the output target(s) of this task, or None if it has none.

        A task does not need a target. NOTE(review): the intent stated here
        previously was that a task without a target is always executed, but
        ``done()`` returns True when there are no targets. Confirm how the
        scheduler treats that case.
        """
        return None

    def done(self):
        """Return True when every output target exists (vacuously True with no targets)."""
        return all(o.exists() for o in to_list(self.target()))

    def runnable(self) -> bool:
        """Return True when all dependencies are satisfied."""
        return depedendencies_resolved(self.depends())

    def unresolved_dependencies(self):
        """Yield each dependency that is not yet satisfied.

        A task dependency is unresolved while ``done()`` is False. Any other
        dependency is treated as a target and is unresolved while ``exists()``
        is False. This matches ``runnable()``; previously only ``LocalTarget``
        dependencies were checked.
        """
        for dep in to_list(self.depends()):
            if isinstance(dep, Task):
                if not dep.done():
                    yield dep
            elif not dep.exists():
                yield dep

    def delete(self, recursive: bool = False):
        """Delete this task's targets; with ``recursive``, also those of task dependencies, transitively."""
        for t in to_list(self.target()):
            t.delete()

        if recursive:
            for dep in to_list(self.depends()):
                if isinstance(dep, Task):
                    dep.delete(recursive=recursive)

    def _get_args(self):
        """Render public instance attributes as comma-joined ``name=value`` pairs, in insertion order."""
        return ",".join(
            f"{name}={value}"
            for name, value in self.__dict__.items()
            if not name.startswith("_")
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._get_args()}>"

    def __hash__(self) -> int:
        return hash(self.__repr__())

    def __eq__(self, other: object) -> bool:
        # Compare the repr strings themselves (the data __hash__ is built from)
        # rather than their hashes. Hash equality admits collisions and made a
        # task compare equal to any object with the same hash, e.g. its repr string.
        if not isinstance(other, Task):
            return NotImplemented
        return repr(other) == repr(self)

    def remote(self, ip: str, username: str, workdir: Union[str, Path], pem: Optional[Union[str, Path]] = None, executable: str = "python3"):
        """Execute this task on a remote host over SSH and return its result.

        Steps:
        - Upload local dependency targets.
        - Pickle a deep copy of the task, upload it, and run it with
          ``executable`` from ``workdir``.
        - Download this task's local targets back to the same paths.
        - Download, unpickle and return the value returned by the remote ``run()``.

        Args:
            ip: host to connect to.
            username: SSH user name.
            workdir: remote directory the command is run from. It is interpolated
                unquoted into a shell command.
            pem: optional private key passed to ``ssh.get_client``.
            executable: remote Python interpreter.

        Raises:
            Exception: a local dependency target is missing, or the remote command
                produced stderr output.
        """
        import os
        import pickle

        # NOTE(review): the client is never closed. Whether ssh.get_client returns a
        # shared/cached client is not visible here; confirm before adding close().
        client = ssh.get_client(ip, username, pem=pem)

        copy_task = copy.deepcopy(self)

        # Sync dependencies: task deps contribute their targets, anything else is
        # treated as a target itself (same classification as depedendencies_resolved).
        for dep in to_list(copy_task.depends()):
            if isinstance(dep, Task):
                targets = to_list(dep.target())
            else:
                targets = [dep]

            for target in targets:
                if isinstance(target, LocalTarget):
                    if not target.exists():
                        raise Exception(f"could not find dependency {target}")

                    # NOTE(review): the file is uploaded to "<path>.copy" and only this
                    # local object is repointed. depends()/target() in this module build
                    # new target objects on every call, so the pickled task most likely
                    # still reads the original path on the remote host. Confirm.
                    remote_target = LocalTarget(str(target.path.absolute()) + ".copy")
                    print(f"=> syncing target {target} to {remote_target}")
                    ssh.upload_file(client, target.path, remote_target.path)
                    target.path = remote_target.path

                elif isinstance(target, S3Target):
                    print(f"not syncing target {target}")

        # TODO: recreate / verify the remote environment before running the task.

        # Serialize the task copy locally, then upload it for the remote interpreter.
        path = f"/tmp/remote_task_{self.__class__.__name__}__{encode_short(self._get_args())}.pkl"
        remote_path = f"/tmp/task_{self.__class__.__name__}__{encode_short(copy_task._get_args())}.pkl"
        remote_result_path = remote_path + '.result'

        with open(path, 'wb') as writer:
            pickle.dump(copy_task, writer)

        assert (os.path.getsize(path) > 0)

        ssh.upload_file(client, path, remote_path)

        # Remotely: unpickle the task, run it, and pickle its return value next to it.
        command = f"cd {workdir} && {executable} -c \"import pickle as p; task = p.load(open('{remote_path}', 'rb')); r=task.run(); p.dump(r,open('{remote_result_path}', 'wb'))\""
        print("execute command: ", command)

        # NOTE(review): with get_pty=True the remote stderr is typically merged into
        # stdout, so the stderr check below may never fire. The whole local
        # os.environ is also forwarded. Confirm both are intended.
        _, stdout, stderr = client.exec_command(command, get_pty=True, environment=os.environ)

        stderr = stderr.readlines()
        stdout = stdout.readlines()

        if len(stderr) > 0:
            raise Exception(f"failed to execute task remotely. stderr: {stderr}. {stdout}")

        # Copy the task's local targets from the remote host to the same local paths.
        print("copy targts from remote to local")

        for target in to_list(self.target()):
            if isinstance(target, LocalTarget):
                local_path = target.path
                print(f"download file {local_path}")
                ssh.download_file(client, target.path, local_path)

        # Fetch and unpickle the remote run() result; close the file afterwards.
        local_result_path = remote_result_path + '.local'
        ssh.download_file(client, remote_result_path, local_result_path)

        with open(local_result_path, 'rb') as reader:
            return pickle.load(reader)


class DepTask(Task):
    """A no-op task that groups dependencies; it is done when all their targets exist."""

    def run(self):
        pass

    def target(self) -> OutputType:
        """Return the flattened targets of all dependencies.

        A task dependency contributes every target from its ``target()`` (a
        single target, list, dict or None). Any other dependency is itself a
        target. Flattening keeps ``done()`` calling ``exists()`` on real targets
        rather than on nested lists/dicts/None.
        """
        targets = []
        for dep in to_list(self.depends()):
            if isinstance(dep, Task):
                targets.extend(to_list(dep.target()))
            else:
                targets.append(dep)
        return targets


class DownloadTask(Task):
    """Download ``url`` to the local file ``destination``.

    The task is done once ``destination`` exists.
    """

    def __init__(self, url: str, destination: Path, auth: Optional[Tuple[str, str]] = None, headers: Optional[Dict[str, str]] = None) -> None:
        # Keep this assignment order: Task._get_args() lists attributes in
        # insertion order, and that string drives repr/hash/equality.
        self.url: str = url
        self.destination: Path = Path(destination)
        self.headers = headers
        self.auth = auth

    def target(self) -> LocalTarget:
        return LocalTarget(self.destination)

    def run(self):
        out_file = str(self.destination.absolute())
        download(self.url, out_file, auth=self.auth, headers=self.headers)


class TempDownloadTask(Task):
    """Download ``url`` into ``/tmp`` under a name derived from a hash of the URL.

    A truthy ``suffix`` (e.g. a file extension) is appended to that name.
    """

    def __init__(self, url: str, auth: Optional[Tuple[str, str]] = None, headers: Optional[Dict[str, str]] = None, suffix: Optional[str] = None) -> None:
        # Attribute order matters for Task._get_args(): url, destination, headers, auth.
        self.url: str = url

        name = str(get_hash(url)) + (suffix if suffix else "")
        self.destination: Path = Path("/tmp/", name).absolute()

        self.headers = headers
        self.auth = auth

    def target(self) -> LocalTarget:
        return LocalTarget(self.destination)

    def run(self):
        """Download the file and return ``{"elapsed": seconds}``."""
        print(f"downloading {self.url} to {self.destination}")
        started = default_timer()
        download(self.url, str(self.destination.absolute()), auth=self.auth, headers=self.headers)
        return {"elapsed": default_timer() - started}


def get_hash(obj: object) -> str:
    """Return a short hex hash of a JSON-serializable object.

    ``obj`` is serialized with ``json.dumps`` using default options (so dict key
    order affects the result) and hashed with ``encode_short``. ``json.dumps``
    raises ``TypeError`` for objects that are not JSON-serializable.
    """
    # Annotation fixed: ``any`` is the builtin function, not a type.
    return encode_short(json.dumps(obj))


class BashTask(Task):
    """Run a local command (argv list, no shell) and store its stdout in a temp file."""

    def __init__(self, cmd: List[str]) -> None:
        self.cmd = cmd

    def run(self):
        """Execute ``cmd`` and write its stdout to the target.

        Returns:
            stdout as a list of byte lines, line endings kept.

        Raises:
            Exception: the command exits non-zero or writes anything to stderr.
        """
        print(f"[Bash] executing {self.cmd}")
        # Reading one pipe to EOF before the other can deadlock once the child
        # fills the other pipe's buffer. subprocess.run() drains both concurrently
        # and waits for exit, so returncode is always set and a non-zero exit is
        # actually detected.
        result = subprocess.run(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # Split exactly like a binary pipe's readlines(): on b"\n" only, endings kept.
        stderr = io.BytesIO(result.stderr).readlines()
        stdout = io.BytesIO(result.stdout).readlines()

        if result.returncode != 0 or len(stderr) > 0:
            raise Exception(f"task {self.__repr__()} [code={result.returncode}] has failed: {stderr}")

        print('stdout: ', stdout)
        print('stderr: ', stderr)

        with self.target().open('wb') as writer:
            writer.writelines(stdout)

        return stdout

    def target(self):
        return LocalTarget(f"/tmp/task_bash_{get_hash(self.cmd)}.output")


class StringOutputTaskBase(Task):
    """Base for tasks producing text; ``run()`` stores ``_run()``'s output in a temp file.

    The output file lives in ``/tmp`` under a name built from the class name and
    a hash of the task's public attributes.
    """

    def _run(self) -> Union[str, List[str]]:
        """Produce the task output (a string or a list of lines); subclasses must override."""
        # Raise before run() opens the target. The previous implicit ``None`` made
        # run() create an empty output file and then fail in writelines(None),
        # leaving a file that made done() report True.
        raise NotImplementedError(f"task {self.__class__.__name__} does not implement _run() method")

    def run(self):
        """Write ``_run()``'s output to the target and return it."""
        output = self._run()

        with self.target().open("w") as writer:
            if isinstance(output, str):
                writer.write(output)
            else:
                writer.writelines(output)

        return output

    def target(self):
        args = self._get_args()
        return LocalTarget(f"/tmp/{self.__class__.__name__}_{encode_short(args)}.output")


class SSHCommandTask(Task):
    """Run a command on a remote host over SSH and store its stdout locally."""

    def __init__(self, ip: str, username: str, cmd: List[str], pem: Optional[Union[str, Path]] = None) -> None:
        # Attribute order matters for Task._get_args(); keep cmd, pem, ip, username.
        self.cmd = cmd
        self.pem = pem
        self.ip = ip
        self.username = username

    def run(self):
        """Execute ``cmd`` (joined with spaces) remotely and write stdout to the target.

        Returns:
            stdout as a list of lines.

        Raises:
            Exception: the remote command wrote anything to stderr.
        """
        logging.info("=> executing %s", self.cmd)

        # NOTE(review): the client is never closed. Whether ssh.get_client returns a
        # shared/cached client is not visible here; confirm before adding close().
        client = ssh.get_client(self.ip, self.username, pem=self.pem)
        _, stdout, stderr = client.exec_command(" ".join(self.cmd))

        # Drain stdout before blocking on stderr. Waiting for stderr EOF first can
        # stall when the command emits enough stdout to exhaust the channel's
        # flow-control window, because nothing consumes it (assuming a
        # paramiko-style channel).
        out = stdout.readlines()
        errors = stderr.readlines()

        if len(errors) > 0:
            raise Exception(f"failed to execute {self.cmd} - stderr: {errors}")

        with self.target().open('w') as writer:
            writer.writelines(out)

        return out

    def target(self):
        return LocalTarget(f"/tmp/task_ssh_{self.ip}_{self.username}_{get_hash(self.cmd)}.output")


class S3UploadTask(Task):
    """Upload a local file to S3.

    The local file is the dependency and the S3 object is the target.
    ``target_uri`` is an ``s3://bucket/path`` string or a ``(bucket, path)`` pair.
    """

    def __init__(self, local_path: Union[str, Path], target_uri: Union[str, Tuple[str, str]]) -> None:
        self.local_path = Path(local_path)
        self.target_uri = target_uri if isinstance(target_uri, str) else f"s3://{target_uri[0]}/{target_uri[1]}"

    def depends(self) -> Dependency:
        return LocalTarget(self.local_path)

    def run(self):
        """Upload ``local_path``; raise AssertionError if the upload reports failure."""
        # Explicit check instead of ``assert``: under ``python -O`` the assert
        # statement, including the upload call inside it, is removed entirely.
        # AssertionError is kept so callers see the same exception type as before.
        if not self.target().upload(self.local_path):
            raise AssertionError(f"failed to upload {self.local_path} to {self.target_uri}")

    def target(self):
        return S3Target.from_uri(self.target_uri)


class S3DownloadTask(Task):
    """Download an S3 object to a local file.

    ``uri`` is an ``s3://bucket/path`` string or a ``(bucket, path)`` pair. The
    S3 object is the dependency and the absolute local path is the target.
    """

    def __init__(self, uri: Union[str, Tuple[str, str]], local_path: Union[str, Path]) -> None:
        # Keep local_path first: Task._get_args() lists attributes in insertion
        # order, and that string drives repr/hash/equality.
        self.local_path = Path(local_path)
        if isinstance(uri, str):
            self.uri = uri
        else:
            self.uri = f"s3://{uri[0]}/{uri[1]}"

    def depends(self):
        return S3Target.from_uri(self.uri)

    def target(self):
        absolute = self.local_path.absolute()
        return LocalTarget(str(absolute))

    def run(self):
        source = self.depends()
        source.download(self.local_path)
