# coding: utf-8

__author__ = "Ondrej Jurcak"

import ast
import logging
from configparser import ConfigParser

from aws_xray_sdk import global_sdk_config
from aws_xray_sdk.core import xray_recorder, patch_all, patch
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
from aws_xray_sdk.ext.flask_sqlalchemy.query import XRayFlaskSqlAlchemy
from aws_xray_sdk.ext.util import construct_xray_header, inject_trace_header
from celery import signals

from core.instrumentation.logging import log


class XRayConfigurator:
    """Configure the AWS X-Ray SDK for a Flask application.

    Settings are read either from a ``ConfigParser`` (options in the
    ``[XRAY]`` section) or from any other object exposing ``XRAY_*``
    attributes. Constructing an instance has global side effects: it enables
    or disables the X-Ray SDK and, when enabled, configures the global
    ``xray_recorder``, patches libraries and wraps ``application``.
    """

    def __init__(self, settings, application) -> None:
        """Read the X-Ray settings and apply them.

        :param settings: a ``ConfigParser`` with an ``[XRAY]`` section, or an
            object with ``XRAY_ENABLED``, ``XRAY_NAME``, ``XRAY_PLUGINS``,
            ``XRAY_SAMPLING``, ``XRAY_STREAM_SQL``, ``XRAY_DAEMON``,
            ``XRAY_PATCH_MODULES``, ``XRAY_SQLALCHEMY``,
            ``XRAY_LOGGING_ENABLED`` and ``XRAY_CELERY`` attributes.
        :param application: the Flask application to instrument.
        :raises configparser.Error: if ``SERVICE_NAME`` is missing from a
            ``ConfigParser`` (it has no default).
        """
        if isinstance(settings, ConfigParser):
            # ConfigParser accepts a default only via the keyword-only
            # ``fallback`` argument (a third positional argument raises
            # TypeError), and ``get`` always returns strings, so flags are
            # read with ``getboolean`` -- otherwise "False" would be truthy.
            section = "XRAY"
            self._init_xray(
                settings.getboolean(section, "ENABLED", fallback=False),
                settings.get(section, "SERVICE_NAME"),
                settings.get(section, "PLUGINS", fallback="'ECSPlugin', 'EC2Plugin'"),
                settings.getboolean(section, "SAMPLING", fallback=False),
                settings.getboolean(section, "STREAM_SQL", fallback=True),
                settings.get(section, "DAEMON", fallback="xray-daemon:2000"),
                settings.get(section, "PATCH_MODULES", fallback=None),
                settings.getboolean(section, "SQLALCHEMY", fallback=True),
                settings.getboolean(section, "LOGGING_ENABLED", fallback=False),
                settings.getboolean(section, "CELERY", fallback=False),
                application,
            )
        else:
            self._init_xray(settings.XRAY_ENABLED,
                            settings.XRAY_NAME,
                            settings.XRAY_PLUGINS,
                            settings.XRAY_SAMPLING,
                            settings.XRAY_STREAM_SQL,
                            settings.XRAY_DAEMON,
                            settings.XRAY_PATCH_MODULES,
                            settings.XRAY_SQLALCHEMY,
                            settings.XRAY_LOGGING_ENABLED,
                            settings.XRAY_CELERY,
                            application)

    @staticmethod
    def _parse_names(value):
        """Normalize a plugin/module list setting into a tuple of names.

        Accepts ``None`` (returned unchanged), a list/tuple/set of names, a
        Python-literal string such as ``"'ECSPlugin', 'EC2Plugin'"`` or
        ``"['requests']"``, or a plain comma-separated string such as
        ``"ECSPlugin, EC2Plugin"``. A single quoted name yields a 1-tuple,
        so it is not iterated character by character.
        """
        if value is None:
            return None
        if isinstance(value, (list, tuple, set, frozenset)):
            return tuple(value)
        try:
            # literal_eval only evaluates literals, so unlike eval() a
            # settings value cannot execute arbitrary code.
            parsed = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError):
            # Not a Python literal: treat it as "name1, name2, ...".
            return tuple(part.strip() for part in value.split(",") if part.strip())
        if isinstance(parsed, str):
            return (parsed,)
        return tuple(parsed)

    def _init_xray(self, enabled, name, plugins, sampling, stream_sql, daemon,
                   patch_modules, sqlalchemy, logging_enabled, celery_enabled,
                   application):
        """Enable and configure the X-Ray SDK, or disable it.

        :param enabled: when falsy, the SDK is disabled and nothing else runs.
        :param name: service name passed to ``xray_recorder.configure``.
        :param plugins: plugin names, as a sequence or string (see
            ``_parse_names``).
        :param sampling: passed to ``xray_recorder.configure(sampling=...)``.
        :param stream_sql: passed to ``xray_recorder.configure(stream_sql=...)``.
        :param daemon: daemon address passed as ``daemon_address``.
        :param patch_modules: modules to patch, as a sequence or string; when
            empty/``None`` every supported library is patched via
            ``patch_all()``.
        :param sqlalchemy: when truthy, wrap ``application`` with
            ``XRayFlaskSqlAlchemy``.
        :param logging_enabled: when truthy, the ``aws_xray_sdk`` logger is
            set to DEBUG; otherwise to NOTSET.
        :param celery_enabled: NOTE(review): currently unused -- the Celery
            signal handlers in this module are connected at import time
            regardless of this flag.
        :param application: the Flask application to instrument.
        """
        if not enabled:
            log.info("xray sdk disabled")
            global_sdk_config.set_sdk_enabled(False)
            return

        log.info("xray sdk enabled " + str(enabled))
        global_sdk_config.set_sdk_enabled(True)
        xray_recorder.configure(service=name,
                                plugins=self._parse_names(plugins),
                                sampling=sampling,
                                stream_sql=stream_sql,
                                daemon_address=daemon,
                                # Log, rather than raise, when no segment is active.
                                context_missing="LOG_ERROR")
        if not patch_modules:
            patch_all()
        else:
            patch(self._parse_names(patch_modules))

        if sqlalchemy:
            log.info("XRayFlaskSqlAlchemy enabled")
            XRayFlaskSqlAlchemy(application)

        XRayMiddleware(application, xray_recorder)

        if logging_enabled:
            log.info("XRay logging enabled")
            logging.getLogger("aws_xray_sdk").setLevel(logging.DEBUG)
        else:
            log.info("XRay logging disabled")
            logging.getLogger("aws_xray_sdk").setLevel(logging.NOTSET)

@signals.task_prerun.connect
def task_prerun(task_id, task, *args, **kwargs):
    """Open an X-Ray segment for a Celery task that is about to run.

    The trace header is read back from ``task.request`` (where
    ``before_task_publish`` injected it on the producer side) so that the
    worker's segment can join the producer's trace.

    :param task_id: id of the task being run.
    :param task: the Celery task; ``task.name`` names the segment.
    """
    log.info(f"task prerun task_id: {task_id}")
    xray_header = construct_xray_header(task.request)
    # root/parent may be None when the request carries no trace header
    # (e.g. a task published outside a traced context); the former
    # print(root + " " + parent) raised TypeError in that case.
    log.info(f"task prerun trace header root: {xray_header.root} parent: {xray_header.parent}")
    segment = xray_recorder.begin_segment(
        name=task.name,
        traceid=xray_header.root,
        parent_id=xray_header.parent,
    )
    segment.save_origin_trace_header(xray_header)
    segment.put_metadata('task_id', task_id, namespace='celery')

@signals.task_postrun.connect
def task_postrun(task_id, *args, **kwargs):
    """Close the X-Ray segment that ``task_prerun`` opened for this task.

    :param task_id: id of the task that just finished.
    """
    message = "task post run task_id: " + task_id
    log.info(message)
    xray_recorder.end_segment()

@signals.before_task_publish.connect
def before_task_publish(sender, headers, **kwargs):
    """Open a ``remote`` subsegment for a task publish and propagate the trace.

    The subsegment is closed by the ``after_task_publish`` handler. The trace
    header is injected into the outgoing message ``headers`` so that the
    worker-side ``task_prerun`` handler can read it from ``task.request``.

    :param sender: the published task's name; used as the subsegment name.
    :param headers: outgoing message headers (dict-like); read for ``id``
        and written with the trace header.
    """
    log.info(f"before taskpublish sender: {sender}")
    subsegment = xray_recorder.begin_subsegment(
        name=sender,
        namespace='remote',
    )
    if subsegment is None:
        # Not inside a segment: nothing to record or propagate.
        return
    subsegment.put_metadata('task_id', headers.get("id"), namespace='celery')
    inject_trace_header(headers, subsegment)

@signals.after_task_publish.connect
def xray_after_task_publish(**kwargs):
    """Close the publish subsegment that ``before_task_publish`` opened."""
    log.info("after task publish")
    xray_recorder.end_subsegment()

@signals.task_failure.connect
def xray_task_failure(einfo, **kwargs):
    """Record a failed task's exception on the current X-Ray segment.

    :param einfo: Celery exception info for the failure; its ``exception``
        attribute is attached to the segment together with a stack trace.
    """
    log.info("task_failure")
    segment = xray_recorder.current_segment()
    # The recorder is configured with context_missing="LOG_ERROR", so there
    # may be no active segment here (e.g. task_prerun failed to open one);
    # guard instead of raising AttributeError on None.
    if segment is None or not einfo:
        return
    stack = stacktrace.get_stacktrace(limit=xray_recorder.max_trace_back)
    segment.add_exception(einfo.exception, stack)
