import boto3
import logging
from time import sleep
from quickbe import Log
from autom8it import AutomationTask

# Keep the AWS SDK and its HTTP transport stack quiet: only records at
# WARNING or above from these loggers are let through.
_QUIET_AWS_LOGGERS = ('boto', 'boto3', 'botocore', 's3transfer', 'urllib3')
for log_name in _QUIET_AWS_LOGGERS:
    logging.getLogger(log_name).setLevel(logging.WARNING)


class DeregisterTarget(AutomationTask):
    """Deregister one target (id + port) from an ELBv2 target group.

    The task is done once ``describe_target_health`` no longer lists that
    target/port pair in the target group.
    """

    TARGET_GROUP_ARN_KEY = 'target_group_arn'
    TARGET_ID_KEY = 'target_id'
    PORT_KEY = 'port'

    aws_client = boto3.client('elbv2')

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data, check_done_interval=10)

    @property
    def validation_schema(self) -> dict:
        return {
            self.TARGET_GROUP_ARN_KEY: {
                'required': True,
                'type': 'string'
            },
            self.TARGET_ID_KEY: {
                'required': True,
                'type': 'string'
            },
            self.PORT_KEY: {
                'required': True,
                'type': 'integer'
            }
        }

    @property
    def task_type(self) -> str:
        return 'Deregister target'

    def do(self):
        """Issue ``deregister_targets`` for the configured target and port.

        Returns:
            The raw boto3 response.
        """
        target_group_desc = self.aws_client.deregister_targets(
            TargetGroupArn=self.get_task_attribute(self.TARGET_GROUP_ARN_KEY),
            Targets=[{
                'Id': self.get_task_attribute(self.TARGET_ID_KEY),
                'Port': self.get_task_attribute(self.PORT_KEY)
            }],
        )
        return target_group_desc

    def is_done(self) -> bool:
        """Return True once the target/port pair is no longer listed.

        While the pair is still listed, in any health state, it logs that
        state and returns False.
        """
        target_id = self.get_task_attribute(self.TARGET_ID_KEY)
        port = self.get_task_attribute(self.PORT_KEY)
        target_group_desc = self.aws_client.describe_target_health(
            TargetGroupArn=self.get_task_attribute(self.TARGET_GROUP_ARN_KEY)
        )
        for target in target_group_desc['TargetHealthDescriptions']:
            registered = target['Target']
            if registered['Id'] != target_id:
                continue
            # The same id may also be registered on other ports; those
            # registrations are not ours. A description without a Port is
            # matched on id alone.
            if registered.get('Port') not in (None, port):
                continue
            Log.debug(f"Target {target_id} state: {target['TargetHealth']['State']}")
            return False

        return True


class RegisterTarget(DeregisterTarget):
    """Register one target (id + port) in an ELBv2 target group.

    The task is done once that target/port pair reports the ``healthy`` state.
    It reuses the parent's attribute keys, validation schema, ``elbv2`` client
    and ``check_done_interval=10``.
    """

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Register target'

    def do(self):
        """Issue ``register_targets`` for the configured target.

        The port falls back to 80 if the attribute is absent.

        Returns:
            The raw boto3 response.
        """
        target_group_desc = self.aws_client.register_targets(
            TargetGroupArn=self.get_task_attribute(self.TARGET_GROUP_ARN_KEY),
            Targets=[{
                'Id': self.get_task_attribute(self.TARGET_ID_KEY),
                'Port': self.get_task_attribute(self.PORT_KEY, 80)
            }],
        )
        return target_group_desc

    def is_done(self) -> bool:
        """Return True once the registered target/port pair is ``healthy``."""
        target_id = self.get_task_attribute(self.TARGET_ID_KEY)
        # Must be the same port (and default) that do() registered.
        port = self.get_task_attribute(self.PORT_KEY, 80)
        target_group_desc = self.aws_client.describe_target_health(
            TargetGroupArn=self.get_task_attribute(self.TARGET_GROUP_ARN_KEY)
        )
        for target in target_group_desc['TargetHealthDescriptions']:
            registered = target['Target']
            # Skip other ids, and the same id registered on a different port:
            # their health says nothing about the registration made here.
            if registered['Id'] != target_id or registered.get('Port') not in (None, port):
                continue
            target_state = target['TargetHealth']['State']
            Log.debug(f"Target {target_id} state: {target_state}")
            if target_state == 'healthy':
                return True

        return False


class StopEC2Instance(AutomationTask):
    """Stop an EC2 instance; done once its state is ``stopped``."""

    INSTANCE_ID_KEY = 'instance_id'

    @property
    def validation_schema(self) -> dict:
        return {
            self.INSTANCE_ID_KEY: {
                'required': True,
                'type': 'string'
            },
        }

    aws_client = boto3.client('ec2')

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Stop EC2 instance'

    def do(self):
        """Issue ``stop_instances`` for the configured instance; return the response."""
        return self.aws_client.stop_instances(
            InstanceIds=[self.get_task_attribute(self.INSTANCE_ID_KEY)],
        )

    def is_done(self) -> bool:
        return self.is_state_equals(requested_state='stopped')

    def is_state_equals(self, requested_state: str) -> bool:
        """Return True if the configured instance is currently in ``requested_state``.

        Args:
            requested_state: EC2 instance state name, e.g. ``'running'`` or
                ``'stopped'``.
        """
        instance_id = self.get_task_attribute(self.INSTANCE_ID_KEY)
        desc = self.aws_client.describe_instance_status(
            InstanceIds=[instance_id],
            # By default EC2 reports only running instances. Without this flag
            # a stopped/stopping instance is never returned and a wait for
            # 'stopped' could never finish.
            IncludeAllInstances=True,
        )
        for data in desc['InstanceStatuses']:
            if data['InstanceId'] == instance_id:
                state = data['InstanceState']['Name']
                Log.debug(f"Instance {instance_id} state: {state}")
                if state == requested_state:
                    return True

        return False


class StartEC2Instance(StopEC2Instance):
    """Start an EC2 instance; done once its state is ``running``.

    Shares the instance-id key, schema and ``ec2`` client with
    ``StopEC2Instance`` and reuses its ``is_state_equals`` state probe.
    """

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Start EC2 instance'

    def do(self):
        """Issue ``start_instances`` for the configured instance; return the response."""
        instance_ids = [self.get_task_attribute(self.INSTANCE_ID_KEY)]
        return self.aws_client.start_instances(InstanceIds=instance_ids)

    def is_done(self) -> bool:
        wanted_state = 'running'
        return self.is_state_equals(requested_state=wanted_state)


class RebootEC2Instance(StopEC2Instance):
    """Reboot an EC2 instance; done once its state is ``running``.

    ``do`` blocks for ``REBOOT_WAIT_SECONDS`` after issuing the reboot.
    NOTE(review): this presumably exists because the instance state may remain
    ``running`` throughout a reboot, which would let ``is_done`` succeed
    immediately. If so, this pause is the only real wait. TODO confirm.
    """

    # Seconds do() sleeps after issuing the reboot. Subclasses or instances
    # may override it.
    REBOOT_WAIT_SECONDS = 90.0

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        # Previously returned 'Start EC2 instance' (copy-paste from StartEC2Instance).
        return 'Reboot EC2 instance'

    def do(self):
        """Issue ``reboot_instances``, pause, and return the reboot response."""
        resp = self.aws_client.reboot_instances(
            InstanceIds=[self.get_task_attribute(self.INSTANCE_ID_KEY)],
        )
        sleep(self.REBOOT_WAIT_SECONDS)
        return resp

    def is_done(self) -> bool:
        return self.is_state_equals(requested_state='running')


class VerifyECSServicesCount(AutomationTask):
    """Wait until every listed ECS service runs as many tasks as it desires.

    ``do`` performs no action. ``is_done`` compares ``desiredCount`` with
    ``runningCount`` for each service in the configured cluster.
    """

    CLUSTER_KEY = 'cluster'
    SERVICES_KEY = 'services'
    SERVICE_RUNNING_COUNT_KEY = 'runningCount'
    SERVICE_DESIRED_COUNT_KEY = 'desiredCount'
    # ECS DescribeServices accepts at most 10 services per request.
    DESCRIBE_SERVICES_BATCH_SIZE = 10

    @property
    def validation_schema(self) -> dict:
        return {
            self.CLUSTER_KEY: {
                'required': True,
                'type': 'string'
            },
            self.SERVICES_KEY: {
                'required': True,
                'type': 'list'
            },
        }

    aws_client = boto3.client('ecs')

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Verify ECS services count'

    def do(self):
        return 'OK'

    def is_done(self) -> bool:
        """Return True once every listed service has desired == running count.

        Services are described in batches to stay within the per-call limit.
        A service that ECS reports under ``failures`` (e.g. not found) counts
        as not ready. It is no longer silently skipped.
        """
        cluster = self.get_task_attribute(self.CLUSTER_KEY)
        services = self.get_task_attribute(self.SERVICES_KEY)
        batch_size = self.DESCRIBE_SERVICES_BATCH_SIZE
        for start in range(0, len(services), batch_size):
            desc = self.aws_client.describe_services(
                cluster=cluster,
                services=services[start:start + batch_size],
            )
            failures = desc.get('failures', [])
            if failures:
                for failure in failures:
                    Log.debug(
                        f"Service {failure.get('arn')} could not be described: "
                        f"{failure.get('reason')}."
                    )
                return False
            for service in desc['services']:
                desired_count = service[self.SERVICE_DESIRED_COUNT_KEY]
                running_count = service[self.SERVICE_RUNNING_COUNT_KEY]
                if desired_count != running_count:
                    Log.debug(
                        f"Service {service['serviceName']} is not ready, "
                        f"desired count is {desired_count} but running count is {running_count}."
                    )
                    return False
        return True


class UpdateServicesWithLastRevision(VerifyECSServicesCount):
    """Point each listed ECS service at its task-definition *family*.

    Each service is updated with the family name only (no ``:revision``),
    derived from the task definition it currently uses. ``is_done`` returns
    True immediately and does not wait for the rollout.
    """

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Update ECS services with last revision'

    def do(self):
        """Call ``update_service`` once per listed service and return all responses."""
        return [
            self.aws_client.update_service(
                cluster=self.get_task_attribute(self.CLUSTER_KEY),
                service=service_name,
                taskDefinition=self._get_task_definition_family(service=service_name),
            )
            for service_name in self.get_task_attribute(self.SERVICES_KEY)
        ]

    def _get_task_definition_family(self, service: str) -> str:
        """Return the family of the task definition ``service`` currently uses.

        The family is taken from the ``taskDefinition`` ARN: the text after the
        first ``/`` and before the following ``:``.

        Raises:
            ValueError: if the describe call does not return exactly one service.
        """
        response = self.aws_client.describe_services(
            cluster=self.get_task_attribute(self.CLUSTER_KEY),
            services=[service]
        )
        matches = response['services']
        Log.debug(f'Services description: {matches}')
        if len(matches) != 1:
            raise ValueError(f'Expected only one service, got {len(matches)}.')
        family_and_revision = matches[0]['taskDefinition'].split('/')[1]
        return family_and_revision.split(':')[0]

    def is_done(self) -> bool:
        return True


class ChangeTaskDefinitionContainerImage(AutomationTask):
    """Register a new ECS task-definition revision with a different container image.

    The ``image`` attribute is either a full image reference (contains ``:``)
    or a bare tag. A bare tag is applied to each container's current
    repository. Every container definition in the task definition is updated.
    NOTE(review): this includes sidecars; confirm that is intended.
    """

    TASK_DEFINITION_KEY = 'task_definition'
    IMAGE_KEY = 'image'
    CONTAINER_DEFINITIONS_KEY = 'containerDefinitions'
    # Task-definition settings returned by DescribeTaskDefinition that
    # RegisterTaskDefinition accepts back. Copying them keeps the new revision
    # identical to the old one apart from the image. Tags are not copied
    # because describing them needs include=['TAGS'].
    REGISTER_TASK_DEFINITION_FIELDS = (
        'taskRoleArn',
        'executionRoleArn',
        'networkMode',
        'volumes',
        'placementConstraints',
        'requiresCompatibilities',
        'cpu',
        'memory',
        'pidMode',
        'ipcMode',
        'proxyConfiguration',
        'inferenceAccelerators',
        'ephemeralStorage',
        'runtimePlatform',
    )

    @property
    def validation_schema(self) -> dict:
        return {
            self.TASK_DEFINITION_KEY: {
                'required': True,
                'type': 'string'
            },
            self.IMAGE_KEY: {
                'required': True,
                'type': 'string'
            },
        }

    aws_client = boto3.client('ecs')

    def __init__(self, task_data: dict):
        super().__init__(task_data=task_data)

    @property
    def task_type(self) -> str:
        return 'Create ECS Task Definition revision'

    @staticmethod
    def _replace_tag(current_image: str, tag: str) -> str:
        """Return ``current_image`` with its tag or digest replaced by ``tag``.

        Only a ``:`` after the last ``/`` is treated as the tag separator, so a
        registry port such as ``host:5000/repo:1.0`` is preserved.
        """
        repository = current_image.split('@', 1)[0]  # drop any @digest
        name_start = repository.rfind('/') + 1
        tag_sep = repository.find(':', name_start)
        if tag_sep != -1:
            repository = repository[:tag_sep]
        return f'{repository}:{tag}'

    def do(self):
        """Describe the task definition, swap images, and register a new revision.

        Returns:
            The raw ``register_task_definition`` response.
        """
        desc = self.aws_client.describe_task_definition(
            taskDefinition=self.get_task_attribute(self.TASK_DEFINITION_KEY)
        )
        Log.debug(f'Task definition description: {desc}')
        task_definition = desc['taskDefinition']
        container_definitions = task_definition[self.CONTAINER_DEFINITIONS_KEY]

        requested_image = self.get_task_attribute(self.IMAGE_KEY)
        for container_def in container_definitions:
            if ':' in requested_image:
                container_def[self.IMAGE_KEY] = requested_image
            else:
                container_def[self.IMAGE_KEY] = self._replace_tag(
                    container_def[self.IMAGE_KEY], requested_image
                )

        # Carry over the rest of the definition. Registering with only the
        # container definitions would drop roles, networking, cpu/memory and
        # launch-type compatibility (Fargate registrations would fail).
        carried_over = {
            field: task_definition[field]
            for field in self.REGISTER_TASK_DEFINITION_FIELDS
            if task_definition.get(field)
        }
        result = self.aws_client.register_task_definition(
            # The described family is always a valid family name, whereas the
            # task attribute may be 'family:revision' or a full ARN.
            family=task_definition['family'],
            containerDefinitions=container_definitions,
            **carried_over,
        )

        return result

    def is_done(self) -> bool:
        return True

