import os
import requests
import time
import base64
import json

from distutils.version import LooseVersion

from metaflow.plugins.aws.aws_client import get_signed_url, get_auth_object
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import METADATA_SERVICE_NUM_RETRIES, METADATA_SERVICE_HEADERS, \
    METADATA_SERVICE_URL
from metaflow.metadata import MetadataProvider
from metaflow.metadata.heartbeat import HB_URL_KEY
from metaflow.sidecar import SidecarSubProcess
from metaflow.sidecar_messages import MessageTypes, Message

# Heartbeat kinds accepted by ServiceMetadataProvider._start_heartbeat
class HeartbeatTypes(object):
    """Integer identifiers for the kinds of heartbeat that can be started."""

    # RUN: heartbeat for a whole run; TASK: heartbeat for a single task.
    RUN, TASK = 1, 2

class ServiceException(MetaflowException):
    """Error returned by (or while talking to) the Metaflow metadata service.

    Attributes:
        http_code: the HTTP status as an int, or None when not known.
        response: the raw response body passed by the caller, if any.
    """
    headline = 'Metaflow service error'

    def __init__(self, msg, http_code=None, body=None):
        if http_code is not None:
            # Normalize so callers can compare against integer codes (e.g. 404).
            http_code = int(http_code)
        self.http_code = http_code
        self.response = body
        super(ServiceException, self).__init__(msg)


class ServiceMetadataProvider(MetadataProvider):
    """Metadata provider backed by the Metaflow metadata service over HTTP.

    Flows, runs, steps, tasks, artifacts and metadata are read and created
    through REST calls rooted at ``cls.INFO`` (the configured service URL).
    Run/task heartbeats are emitted through a sidecar process.
    """
    TYPE = 'service'

    def __init__(self, environment, flow, event_logger, monitor):
        super(ServiceMetadataProvider, self).__init__(environment, flow, event_logger, monitor)
        self.url_task_template = os.path.join(
            METADATA_SERVICE_URL,
            'flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks/{task_id}/heartbeat')
        self.url_run_template = os.path.join(
            METADATA_SERVICE_URL,
            'flows/{flow_id}/runs/{run_number}/heartbeat')
        # Set once a heartbeat sidecar has been started (see _already_started).
        self.sidecar_process = None

    @classmethod
    def compute_info(cls, val):
        """Check that the service at ``val`` answers ``/ping``.

        Returns:
            ``val`` with any trailing '/' removed.

        Raises:
            ValueError: if the ping request fails or returns an HTTP error.
        """
        v = val.rstrip('/')
        try:
            resp = requests.get(os.path.join(v, 'ping'), headers=METADATA_SERVICE_HEADERS)
            resp.raise_for_status()
        except Exception:
            # Any transport or HTTP error means the service is unusable.
            # Catching Exception (not a bare except) lets Ctrl-C propagate.
            raise ValueError('Metaflow service [%s] unreachable.' % v)
        return v

    @classmethod
    def default_info(cls):
        return METADATA_SERVICE_URL

    def version(self):
        return self._version(self._monitor)

    def new_run_id(self, tags=[], sys_tags=[]):
        """Create a new run on the service and return its run number as a str."""
        return self._new_run(tags=tags, sys_tags=sys_tags)

    def register_run_id(self, run_id, tags=[], sys_tags=[]):
        """Register a user-supplied run id with the service.

        Integer ids are treated as already issued by the service; for them
        nothing is sent and None is returned. Any other id is fetched or
        created via ``_new_run``, whose run number is returned.
        """
        try:
            # don't try to register an integer ID which was obtained
            # from the metadata service in the first place
            int(run_id)
            return
        except ValueError:
            return self._new_run(run_id, tags=tags, sys_tags=sys_tags)

    def new_task_id(self, run_id, step_name, tags=[], sys_tags=[]):
        """Create a new task under ``run_id``/``step_name`` and return its id."""
        return self._new_task(run_id, step_name, tags=tags, sys_tags=sys_tags)

    def register_task_id(self,
                         run_id,
                         step_name,
                         task_id,
                         tags=[],
                         sys_tags=[]):
        """Register a task id and its code-package metadata with the service.

        Integer ids are treated as already issued by the service, so only
        the code-package metadata is registered for them. Any other id is
        fetched or created through ``_new_task``. ``_new_task`` registers
        the code-package metadata itself, so it is not registered a second
        time here.
        """
        try:
            # don't try to register an integer ID which was obtained
            # from the metadata service in the first place
            int(task_id)
        except ValueError:
            self._new_task(run_id,
                           step_name,
                           task_id,
                           tags=tags,
                           sys_tags=sys_tags)
        else:
            self._register_code_package_metadata(run_id, step_name, task_id)

    def _start_heartbeat(self, heartbeat_type, flow_id, run_id, step_name=None, task_id=None):
        """Start the heartbeat sidecar for a run or a task.

        Raises:
            Exception: if a heartbeat was already started on this instance,
                or if ``heartbeat_type`` is not a ``HeartbeatTypes`` value.
        """
        if self._already_started():
            # A single ServiceMetadataProvider instance can not start
            # multiple heartbeat side cars of any type/combination. Either a
            # single run heartbeat or a single task heartbeat can be started
            raise Exception("heartbeat already started")
        # Build the init payload before spawning the sidecar. That way an
        # invalid heartbeat type does not leave a sidecar process running.
        payload = {}
        if heartbeat_type == HeartbeatTypes.TASK:
            data = {
                'flow_id': flow_id, 'run_number': run_id,
                'step_name': step_name, 'task_id': task_id,
            }
            payload[HB_URL_KEY] = self.url_task_template.format(**data)
        elif heartbeat_type == HeartbeatTypes.RUN:
            data = {'flow_id': flow_id, 'run_number': run_id}
            payload[HB_URL_KEY] = self.url_run_template.format(**data)
        else:
            raise Exception("invalid heartbeat type")
        # Query the version once. It both selects the sidecar and is
        # forwarded to it.
        service_version = self.version()
        payload["service_version"] = service_version
        if service_version is None or \
                LooseVersion(service_version) < LooseVersion('2.0.4'):
            # if old version of the service is running
            # then avoid running real heartbeat sidecar process
            self.sidecar_process = SidecarSubProcess("nullSidecarHeartbeat")
        else:
            self.sidecar_process = SidecarSubProcess("heartbeat")
        msg = Message(MessageTypes.LOG_EVENT, payload)
        self.sidecar_process.msg_handler(msg)

    def start_run_heartbeat(self, flow_id, run_id):
        self._start_heartbeat(HeartbeatTypes.RUN, flow_id, run_id)

    def start_task_heartbeat(self, flow_id, run_id, step_name, task_id):
        self._start_heartbeat(HeartbeatTypes.TASK,
                              flow_id,
                              run_id,
                              step_name,
                              task_id)

    def _already_started(self):
        return self.sidecar_process is not None

    def stop_heartbeat(self):
        """Send a shutdown message to the heartbeat sidecar, if one was started."""
        if self.sidecar_process is None:
            # No heartbeat was started, so there is nothing to stop.
            return
        msg = Message(MessageTypes.SHUTDOWN, None)
        self.sidecar_process.msg_handler(msg)

    def register_data_artifacts(self,
                                run_id,
                                step_name,
                                task_id,
                                attempt_id,
                                artifacts):
        """POST the given artifacts to the task's ``/artifact`` endpoint."""
        url = ServiceMetadataProvider._obj_path(self._flow_name, run_id, step_name, task_id)
        url += '/artifact'
        data = self._artifacts_to_json(run_id, step_name, task_id, attempt_id, artifacts)
        self._request(self._monitor, url, data)

    def register_metadata(self, run_id, step_name, task_id, metadata):
        """POST the given metadata to the task's ``/metadata`` endpoint."""
        url = ServiceMetadataProvider._obj_path(self._flow_name, run_id, step_name, task_id)
        url += '/metadata'
        data = self._metadata_to_json(run_id, step_name, task_id, metadata)
        self._request(self._monitor, url, data)

    @classmethod
    def _get_object_internal(cls, obj_type, obj_order, sub_type, sub_order, filters=None, *args):
        """Fetch an object, or the list of its children, from the service.

        ``args`` holds the positional path components accepted by
        ``_obj_path`` (flow, run, step, task, artifact). The first
        ``obj_order`` of them address the object. With ``sub_type == 'self'``
        the object itself is returned; otherwise the ``sub_type`` children
        under it are returned.

        Returns None on a 404. For 'self', it also returns None when
        ``filters`` reject the object.
        """
        # Special handling of self, artifact, and metadata
        if sub_type == 'self':
            url = ServiceMetadataProvider._obj_path(*args[:obj_order])
            try:
                matches = MetadataProvider._apply_filter([cls._request(None, url)], filters)
            except ServiceException as ex:
                if ex.http_code == 404:
                    return None
                raise
            # An object rejected by the filters is reported as "not found"
            # rather than raising IndexError.
            return matches[0] if matches else None

        # For the other types, we locate all the objects we need to find and return them
        if obj_type != 'root':
            url = ServiceMetadataProvider._obj_path(*args[:obj_order])
        else:
            url = ''
        # 'metadata' is already plural; other sub types are pluralized
        # (e.g. '/runs', '/tasks').
        if sub_type != 'metadata':
            url += '/%ss' % sub_type
        else:
            url += '/metadata'
        try:
            return MetadataProvider._apply_filter(cls._request(None, url), filters)
        except ServiceException as ex:
            if ex.http_code == 404:
                return None
            raise

    def _new_run(self, run_id=None, tags=[], sys_tags=[]):
        """Ensure the flow exists, then fetch/create the run; return its run number as str."""
        # first ensure that the flow exists
        self._get_or_create('flow')
        run = self._get_or_create('run', run_id, tags=tags, sys_tags=sys_tags)
        return str(run['run_number'])

    def _new_task(self,
                  run_id,
                  step_name,
                  task_id=None,
                  tags=[],
                  sys_tags=[]):
        """Ensure the step exists, fetch/create the task, and register its code package.

        Returns:
            The ``task_id`` field of the task object returned by the service.
        """
        # first ensure that the step exists
        self._get_or_create('step', run_id, step_name)
        task = self._get_or_create('task', run_id, step_name, task_id, tags=tags, sys_tags=sys_tags)
        self._register_code_package_metadata(run_id, step_name, task['task_id'])
        return task['task_id']

    @staticmethod
    def _obj_path(
            flow_name, run_id=None, step_name=None, task_id=None, artifact_name=None):
        """Build an object path: '/flows/<flow>' plus one segment per truthy component.

        Segments are appended in order: '/runs/<run>', '/steps/<step>',
        '/tasks/<task>', '/artifacts/<artifact>'.
        """
        object_path = '/flows/%s' % flow_name
        if run_id:
            object_path += '/runs/%s' % run_id
        if step_name:
            object_path += '/steps/%s' % step_name
        if task_id:
            object_path += '/tasks/%s' % task_id
        if artifact_name:
            object_path += '/artifacts/%s' % artifact_name
        return object_path

    @staticmethod
    def _create_path(obj_type, flow_name, run_id=None, step_name=None):
        """Return the path to POST to in order to create an object of ``obj_type``.

        'flow' -> '/flows/<f>', 'run' -> '/flows/<f>/run',
        'step' -> '/flows/<f>/runs/<r>/steps/<s>/step', anything else ->
        '/flows/<f>/runs/<r>/steps/<s>/task'.
        """
        create_path = '/flows/%s' % flow_name
        if obj_type == 'flow':
            return create_path
        if obj_type == 'run':
            return create_path + '/run'
        create_path += '/runs/%s/steps/%s' % (run_id, step_name)
        if obj_type == 'step':
            return create_path + '/step'
        return create_path + '/task'

    def _get_or_create(
            self, obj_type, run_id=None, step_name=None, task_id=None, tags=[], sys_tags=[]):
        """Return the service object for ``obj_type``, creating it if it is missing.

        A run without ``run_id`` or a task without ``task_id`` is always
        created; the created object carries the new id. Otherwise a GET is
        tried first and the object is created on a 404.
        """

        def create_object():
            # tags/sys_tags are concatenated, never mutated, so the shared
            # list defaults in the signature are safe here.
            data = self._object_to_json(
                obj_type,
                run_id,
                step_name,
                task_id,
                tags + self.sticky_tags,
                sys_tags + self.sticky_sys_tags)
            # obj_path is the 409 fallback: if the object already exists,
            # it is fetched instead of failing.
            return self._request(self._monitor, create_path, data, obj_path)

        obj_path = self._obj_path(self._flow_name, run_id, step_name, task_id)
        create_path = self._create_path(obj_type, self._flow_name, run_id, step_name)
        always_create = (obj_type == 'run' and run_id is None) or \
            (obj_type == 'task' and task_id is None)
        if always_create:
            return create_object()

        try:
            return self._request(self._monitor, obj_path)
        except ServiceException as ex:
            if ex.http_code == 404:
                return create_object()
            raise

    @staticmethod
    def _encode_last_evaluated_key(last_evaluated_key):
        """Turn a ``LastEvaluatedKey`` string into the ``last_evaluated`` query value.

        For example, ``"{task_id={N=2520}}"`` becomes the base64 encoding of
        the JSON ``{"task_id": {"N": "2520"}}``. Only this single-attribute
        form is handled.
        """
        flat = last_evaluated_key.replace("}", "").replace("{", "")
        # maxsplit=2 keeps any '=' inside the attribute value intact.
        attr_name, attr_type, attr_value = flat.split("=", 2)
        exclusive_start_key = {attr_name: {attr_type: attr_value}}
        return base64.b64encode(json.dumps(exclusive_start_key).encode()).decode('ascii')

    @classmethod
    def _get_request_data(cls, http_function, url, headers, auth, data=None):
        """Issue a request and follow ``LastEvaluatedKey`` pagination.

        Each successful response is read as ``{"data": ...}``. While a
        response contains ``LastEvaluatedKey``, the next page is requested
        with that key passed as the ``last_evaluated`` query parameter.

        Returns:
            A ``(body, status_code)`` tuple. On success, ``body`` is the
            single page's ``data``; for several pages it is all pages'
            ``data`` lists concatenated. When any page answers with a status
            >= 300, ``body`` is that response's text.
        """
        kwargs = {
            "url": url,
            "headers": headers,
            "auth": auth
        }
        if data:
            kwargs['json'] = data

        resp = http_function(**kwargs)
        # Any error or other non-success status code: return it as-is
        if resp.status_code >= 300:
            return resp.text, resp.status_code
        payload = resp.json()
        pages = [payload['data']]

        while 'LastEvaluatedKey' in payload:
            kwargs['params'] = {
                "last_evaluated": cls._encode_last_evaluated_key(payload['LastEvaluatedKey'])
            }
            resp = http_function(**kwargs)
            if resp.status_code >= 300:
                # Report a failing page like a failing first request so the
                # caller applies its normal error/retry handling.
                return resp.text, resp.status_code
            payload = resp.json()
            pages.append(payload['data'])

        if len(pages) == 1:
            return pages[0], resp.status_code
        # Break out the list of lists into a single list
        return [item for page in pages for item in page], resp.status_code

    @classmethod
    def _perform_request(cls, url, data, auth):
        """Do a single GET (``data is None``) or POST; return ``(body, status_code)``.

        GETs go through ``_get_request_data`` (paginated). For POSTs,
        ``body`` is the decoded JSON on status < 300 and the raw text
        otherwise.
        """
        if data is None:
            return cls._get_request_data(requests.get, url, METADATA_SERVICE_HEADERS, auth)
        resp = requests.post(url, headers=METADATA_SERVICE_HEADERS, json=data, auth=auth)
        if resp.status_code < 300:
            return resp.json(), resp.status_code
        return resp.text, resp.status_code

    @classmethod
    def _request(cls, monitor, path, data=None, retry_409_path=None):
        """Send a request to the metadata service, retrying transient failures.

        A GET is issued when ``data`` is None; otherwise ``data`` is POSTed as
        JSON. Raised exceptions and 503 responses are retried up to
        ``METADATA_SERVICE_NUM_RETRIES`` times, sleeping ``2**i`` seconds
        between attempts.

        Args:
            monitor: optional monitor used to time requests and count
                failures; may be None.
            path: service path, joined onto ``cls.INFO``.
            data: JSON body to POST, or None for a GET.
            retry_409_path: path to GET instead when a POST returns 409.

        Returns:
            The response body. It is None when a POST returns 409 and no
            ``retry_409_path`` is given.

        Raises:
            MetaflowException: if no service URL is configured.
            ServiceException: on a non-retryable HTTP error, or once retries
                are exhausted on 503 responses.
        """
        if cls.INFO is None:
            raise MetaflowException('Missing Metaflow Service URL. '
                'Specify with METAFLOW_SERVICE_URL environment variable')
        url = os.path.join(cls.INFO, path.lstrip('/'))

        auth = get_auth_object(url)
        if data is None:
            metric = 'metaflow.service_metadata.get'
        else:
            metric = 'metaflow.service_metadata.post'
        body = None
        status_code = None
        for i in range(METADATA_SERVICE_NUM_RETRIES):
            last_attempt = i == METADATA_SERVICE_NUM_RETRIES - 1
            try:
                if monitor:
                    with monitor.measure(metric):
                        body, status_code = cls._perform_request(url, data, auth)
                else:
                    body, status_code = cls._perform_request(url, data, auth)
            except Exception:
                # Transport/decoding errors are retried; the last one propagates.
                if monitor:
                    with monitor.count('metaflow.service_metadata.failed_request'):
                        if last_attempt:
                            raise
                elif last_attempt:
                    raise
            else:
                if status_code < 300:
                    return body
                elif status_code == 409 and data is not None:
                    # a special case: the post fails due to a conflict
                    # this could occur when we missed a success response
                    # from the first POST request but the request
                    # actually went through, so a subsequent POST
                    # returns 409 (conflict) or we end up with a
                    # conflict while running on AWS Step Functions
                    # instead of retrying the post we retry with a get since
                    # the record is guaranteed to exist
                    if retry_409_path:
                        return cls._request(monitor, retry_409_path)
                    return None
                elif status_code != 503:
                    raise ServiceException('Metadata request (%s) failed (code %s): %s'
                                           % (path, status_code, body),
                                           status_code,
                                           body)
            # Back off before the next attempt; no point sleeping after the last one.
            if not last_attempt:
                time.sleep(2**i)

        if body:
            raise ServiceException('Metadata request (%s) failed (code %s): %s'
                                   % (path, status_code, body),
                                   status_code,
                                   body)
        raise ServiceException('Metadata request (%s) failed' % path)

    @classmethod
    def _version(cls, monitor):
        """Return the service version; this implementation always returns None.

        A None version makes ``_start_heartbeat`` use the
        "nullSidecarHeartbeat" sidecar.
        """
        return None
