import functools
import logging
import traceback

import opentracing
from opentracing.ext import tags
import six


class DjangoTracing(object):
    """
    Trace Django requests with an OpenTracing tracer.

    @param tracer the OpenTracing tracer to be used
    to trace requests using this DjangoTracing; when omitted (None) the
    global ``opentracing.tracer`` is looked up on every access instead
    @param start_span_cb optional callable invoked as
    ``start_span_cb(span, request)`` right after a request span is started
    @raises ValueError if start_span_cb is given but is not callable
    """

    def __init__(self, tracer=None, start_span_cb=None):
        if start_span_cb is not None and not callable(start_span_cb):
            raise ValueError("start_span_cb is not callable")

        self._tracer_implementation = tracer
        self._start_span_cb = start_span_cb
        # Active scopes keyed by the request object they trace.
        self._current_scopes = {}
        # When True the @trace decorator is a pass-through, because requests
        # are presumably already traced elsewhere (e.g. middleware).
        self._trace_all = False
        # Whether responses get a Server-Timing "traceparent" header.
        self._trace_response_header_enabled = True

    def _get_tracer_impl(self):
        return self._tracer_implementation

    @property
    def tracer(self):
        """The explicitly configured tracer, else the global opentracing one."""
        if self._tracer_implementation is not None:
            return self._tracer_implementation
        return opentracing.tracer

    @property
    def _tracer(self):
        """DEPRECATED"""
        return self.tracer

    def get_span(self, request):
        """
        @param request
        Returns the span tracing this request, or None if it is not traced
        """
        scope = self._current_scopes.get(request, None)
        return None if scope is None else scope.span

    def trace(self, *attributes):
        """
        Function decorator that traces Django view functions.

        @param attributes any number of request attribute names (strings)
        whose non-empty string values are set as tags on the created span
        """

        def decorator(view_func):
            # TODO: do we want to provide option of overriding
            # trace_all_requests so that they can trace certain attributes
            # of the request for just this request (this would require to
            # reinstate the name-mangling with a trace identifier, and another
            # settings key)

            # functools.wraps keeps the view's name/docstring and copies its
            # __dict__, so markers set by other decorators (e.g. Django's
            # csrf_exempt) survive this wrapper.
            @functools.wraps(view_func)
            def wrapper(request, *args, **kwargs):
                # If every request is already traced, don't open a second
                # span -- but still forward all of the view's arguments.
                if self._trace_all:
                    return view_func(request, *args, **kwargs)

                # Otherwise, apply tracing around the view call.
                try:
                    self._apply_tracing(request, view_func, list(attributes))
                    r = view_func(request, *args, **kwargs)
                except Exception as exc:
                    self._finish_tracing(request, error=exc)
                    raise

                self._finish_tracing(request, r)
                return r

            return wrapper

        return decorator

    def _apply_tracing(self, request, view_func, attributes):
        """
        Start an active span for ``request`` and register its scope.

        Helper shared by the middleware and decorator code paths. The span is
        named after ``view_func`` and is started as a child of any span
        context the tracer can extract from the request headers.

        @param request the Django request being traced
        @param view_func the view handling the request; its ``__name__`` is
        used as the operation name
        @param attributes request attribute names to record as span tags
        @return the scope opened for the request span
        """
        # Turn request.META keys such as "HTTP_X_B3_TRACEID" into header-style
        # keys ("x-b3-traceid") for the tracer's HTTP_HEADERS extractor.
        headers = {}
        for k, v in six.iteritems(request.META):
            k = k.lower().replace("_", "-")
            if k.startswith("http-"):
                k = k[5:]
            headers[k] = v

        tracer = self.tracer
        operation_name = view_func.__name__
        try:
            span_ctx = tracer.extract(opentracing.Format.HTTP_HEADERS, headers)
            scope = tracer.start_active_span(operation_name, child_of=span_ctx)
        except (
            opentracing.InvalidCarrierException,
            opentracing.SpanContextCorruptedException,
        ):
            # Unusable propagation headers: start the span without the
            # extracted parent.
            scope = tracer.start_active_span(operation_name)

        # Remember the scope so _finish_tracing can close it later.
        self._current_scopes[request] = scope

        span = scope.span

        # standard tags
        span.set_tag(tags.COMPONENT, "django")
        span.set_tag(tags.SPAN_KIND, tags.SPAN_KIND_RPC_SERVER)
        span.set_tag(tags.HTTP_METHOD, request.method)
        span.set_tag(tags.HTTP_URL, request.get_full_path())

        # Tag any requested attributes that exist and stringify non-empty.
        for attr in attributes:
            if hasattr(request, attr):
                payload = str(getattr(request, attr))
                if payload:
                    span.set_tag(attr, payload)

        # invoke the start span callback, if any
        self._call_start_span_cb(span, request)

        return scope

    def _finish_tracing(self, request, response=None, error=None):
        """
        Tag and close the scope registered for ``request``, if any.

        @param request the request whose scope should be finished
        @param response the view's response; its status code is tagged and,
        when enabled, a Server-Timing traceparent header is added to it
        @param error the exception raised by the view, if any; recorded as
        error tags on the span
        """
        scope = self._current_scopes.pop(request, None)
        if scope is None:
            return

        span = scope.span

        if error is not None:
            span.set_tag(tags.ERROR, True)
            span.set_tag("sfx.error.message", str(error))
            span.set_tag("sfx.error.object", str(error.__class__))
            span.set_tag("sfx.error.kind", error.__class__.__name__)
            # NOTE: format_exc() describes the exception currently being
            # handled, so this is only meaningful when called from inside an
            # except block (as the @trace wrapper does).
            span.set_tag("sfx.error.stack", traceback.format_exc())

        if response is not None:
            span.set_tag(tags.HTTP_STATUS_CODE, response.status_code)
            if self._trace_response_header_enabled:
                # Tracers whose contexts lack these ids get no header.
                trace_id = getattr(span.context, "trace_id", 0)
                span_id = getattr(span.context, "span_id", 0)
                if trace_id and span_id:
                    add_response_header(
                        response, "Access-Control-Expose-Headers", "Server-Timing"
                    )
                    # NOTE(review): W3C traceparent specifies a 32-hex-digit
                    # trace-id; 64-bit ids render as 16 digits here. Confirm
                    # that consumers of this header expect this format.
                    add_response_header(
                        response,
                        "Server-Timing",
                        'traceparent;desc="00-{trace_id}-{span_id}-01"'.format(
                            trace_id="{:016x}".format(trace_id),
                            span_id="{:016x}".format(span_id),
                        ),
                    )

        scope.close()

    def _call_start_span_cb(self, span, request):
        """Invoke the user's start-span callback, never letting it raise."""
        if self._start_span_cb is None:
            return

        try:
            self._start_span_cb(span, request)
        except Exception:
            # Best-effort: a failing user callback must not break request
            # handling, but record it so the failure is not invisible.
            logging.getLogger(__name__).debug(
                "start_span_cb raised an exception", exc_info=True
            )


def initialize_global_tracer(tracing):
    """
    Install ``tracing``'s explicitly configured tracer as the global tracer.

    Follows the initialisation convention described at
    https://github.com/opentracing/opentracing-python/blob/9f9ef02d4ef7863fb26d3534a38ccdccf245494c/opentracing/__init__.py#L36 # noqa

    The work happens only once: the first call sets the
    ``initialize_global_tracer.complete`` flag, and every later call returns
    immediately without inspecting ``tracing``. If ``tracing`` has no tracer
    of its own, ``opentracing.tracer`` is left untouched.
    """
    if not initialize_global_tracer.complete:
        configured_tracer = tracing._tracer_implementation
        # A None tracer means DjangoTracing already falls back to the global
        # tracer, so there is nothing to install.
        if configured_tracer is not None:
            opentracing.tracer = configured_tracer
        initialize_global_tracer.complete = True


# One-shot guard for initialize_global_tracer(); set to True by its first call.
initialize_global_tracer.complete = False


def add_response_header(response, name, value):
    """
    Append ``value`` to the ``name`` header of ``response``.

    A non-empty existing value is kept and ``value`` is appended after a
    comma (no space). Otherwise the header is simply set to ``value``.
    """
    current = response.get(name, "")
    response[name] = current + "," + value if current else value
