#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import traceback
from datetime import datetime
import json
import logging
import numpy as np

# strftime format of the ``time`` field stamped onto every structured model log
# line (ISO-8601-like, second resolution, no timezone suffix).
MODEL_LOG_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S'


class LogType(object):
    '''
    Allowed values of the ``type`` field that tags each structured log line.
    '''
    # Declares a plot over a set of metrics.
    PLOT = 'PLOT'
    # Carries a set of metric name/value pairs at one point in time.
    METRICS = 'METRICS'
    # Carries a free-text message.
    MESSAGE = 'MESSAGE'


class LoggerUtils():
    '''
    Allows models to log messages & metrics during model training.

    This should NOT be initialized outside of the module. Instead,
    import the global ``utils`` instance from the module ``singa_auto.model``
    and use ``utils.logger``.

    Every entry is emitted at INFO level as a single JSON line carrying a
    ``type`` field (one of :class:`LogType`) and a ``time`` field, so it can be
    recovered later with :meth:`parse_logs`.

    For example:

    ::

        from singa_auto.model import utils
        ...
        def train(self, dataset_path, **kwargs):
            ...
            utils.logger.log('Starting model training...')
            utils.logger.define_plot('Precision & Recall', ['precision', 'recall'], x_axis='epoch')
            ...
            utils.logger.log(precision=0.1, recall=0.6, epoch=1)
            ...
            utils.logger.log('Ending model training...')
            ...

    '''

    def __init__(self):
        # By default, set a logging handler to print to stdout (for debugging)
        logger = logging.getLogger(__name__)
        logger.setLevel(level=logging.INFO)
        # ``getLogger`` hands back the same logger object to every instance, so
        # attach the debug handler only once; otherwise each extra instance
        # would make every log line print one more time.
        if not any(isinstance(handler, LoggerUtilsDebugHandler)
                   for handler in logger.handlers):
            logger.addHandler(LoggerUtilsDebugHandler())
        self._logger = logger

    def define_loss_plot(self):
        '''
        Convenience method of defining a plot of ``loss`` against ``epoch``.
        To be used with :meth:`singa_auto.model.LoggerUtils.log_loss`.
        '''
        self.define_plot('Loss Over Epochs', ['loss'], x_axis='epoch')

    def log_loss(self, loss, epoch):
        '''
        Convenience method for logging `loss` against `epoch`.
        To be used with :meth:`singa_auto.model.LoggerUtils.define_loss_plot`.
        '''
        self.log(loss=loss, epoch=epoch)

    def define_plot(self, title, metrics, x_axis=None):
        '''
        Defines a plot for a set of metrics for analysis of model training.
        By default, metrics will be plotted against time.

        For example, a model's precision & recall logged with e.g. ``log(precision=0.1, recall=0.6, epoch=1)``
        can be visualized in the plots generated by
        ``define_plot('Precision & Recall', ['precision', 'recall'])`` (against time) or
        ``define_plot('Precision & Recall', ['precision', 'recall'], x_axis='epoch')`` (against epochs).

        Only call this method in :meth:`singa_auto.model.BaseModel.train`.

        :param str title: Title of the plot
        :param metrics: List of metrics that should be plotted on the y-axis
        :type metrics: str[]
        :param str x_axis: Metric that should be plotted on the x-axis, against all other metrics. Defaults to ``'time'``, which is automatically logged
        '''
        self._log(LogType.PLOT, {
            'title': title,
            'metrics': metrics,
            'x_axis': x_axis
        })

    def log(self, msg='', **metrics):
        '''
        Logs a message and/or a set of metrics at a single point in time.

        Logged messages will be viewable on SINGA-Auto's administrative UI.

        To visualize logged metrics on plots, a plot must be defined via :meth:`singa_auto.model.LoggerUtils.define_plot`.

        Only call this method in :meth:`singa_auto.model.BaseModel.train` and :meth:`singa_auto.model.BaseModel.evaluate`.

        The metric names ``type`` and ``time`` are reserved: they are
        overwritten by the fields the logger stamps onto every entry.

        :param str msg: Message to be logged
        :param metrics: Set of metrics & their values to be logged as ``{ <metric>: <value> }``, where ``<value>`` should be a number.
        :type metrics: dict[str, int|float]
        :raises TypeError: If a metric value is not a (Python or NumPy) number
        '''
        if msg:
            self._log(LogType.MESSAGE, {'message': str(msg)})

        if metrics:
            metrics = self._validate_metrics(metrics)
            self._log(LogType.METRICS, metrics)

    # - INTERNAL METHOD -
    # Set the Python logger internally used.
    # During model training, this method will be called by SINGA-Auto to inject a Python logger
    # to generate logs for an instance of model training.
    def set_logger(self, logger):
        self._logger = logger

    def _validate_metrics(self, metrics):
        '''Returns a new dict with every metric value validated/normalized.'''
        return {n: self._validate_metric(n, v) for (n, v) in metrics.items()}

    def _log(self, log_type, log_dict=None):
        '''
        Emits ``log_dict`` as one JSON line at INFO level, tagged with
        ``log_type`` and the current local time.

        The caller's dict is copied, never mutated. The default is ``None``
        rather than a shared mutable ``{}``.
        '''
        entry = dict(log_dict) if log_dict else {}
        # Assigned last so they always win over caller-supplied keys.
        entry['type'] = log_type
        entry['time'] = datetime.now().strftime(MODEL_LOG_DATETIME_FORMAT)
        self._logger.info(json.dumps(entry))

    def _validate_metric(self, name, value):
        '''
        Returns ``value`` as a JSON-serializable Python number.

        :raises TypeError: If ``value`` is neither a Python nor a NumPy int/float
        '''
        # NumPy scalars of any width (int8..int64, float16..float64, ...) are
        # converted to native numbers so that ``json.dumps`` accepts them.
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)

        if not isinstance(value, (int, float)):
            raise TypeError(
                'Metric of name "{}" should be an `int` or `float`, but is of `{}`'
                .format(name, type(value)))

        return value

    @staticmethod
    def parse_log_line(log_line):
        '''
        Parses a logged line into a dictionary.

        Anything that does not decode to a JSON object is treated as a plain
        message. This covers plain text, non-string input, and JSON scalars or
        arrays such as ``"42"`` or ``"null"``.
        '''
        try:
            parsed = json.loads(log_line)
        except (TypeError, ValueError):
            # An unserializable log line is a message
            parsed = None

        if isinstance(parsed, dict):
            return parsed
        return {'type': LogType.MESSAGE, 'message': log_line}

    @staticmethod
    def parse_logs(log_lines):
        '''
        Parses logs into ``(messages, metrics, plots)`` for visualization.

        JSON objects without a ``type`` field are skipped. The ``type`` field
        is stripped from every returned entry.

        :returns: A tuple of three lists of dicts: messages (``time``,
            ``message``), metrics (``time`` plus each metric), and plots
            (``title``, ``metrics``, ``x_axis`` plus ``time``)
        '''
        plots = []
        metrics = []
        messages = []

        for log_line in log_lines:
            log_dict = LoggerUtils.parse_log_line(log_line)

            if 'type' not in log_dict:
                continue

            log_type = log_dict.pop('type')

            if log_type == LogType.MESSAGE:
                messages.append({
                    'time': log_dict.get('time'),
                    'message': log_dict.get('message')
                })

            elif log_type == LogType.METRICS:
                # Guarantees a 'time' key even if the line lacked one.
                metrics.append({'time': log_dict.get('time'), **log_dict})

            elif log_type == LogType.PLOT:
                plots.append({**log_dict})

        return (messages, metrics, plots)


class LoggerUtilsDebugHandler(logging.Handler):
    '''
    Logging handler that pretty-prints ``LoggerUtils`` log lines to stdout.

    ``LoggerUtils`` attaches it by default so that structured log lines are
    human-readable while debugging.
    '''

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            # getMessage() applies any %-style args and stringifies non-str
            # messages; reading ``record.msg`` directly would skip both.
            log_line = record.getMessage()
            log_dict = LoggerUtils.parse_log_line(log_line)
            log_type = log_dict.get('type') if isinstance(log_dict, dict) else None

            if log_type == LogType.PLOT:
                title = log_dict.get('title')
                metrics = log_dict.get('metrics') or []
                x_axis = log_dict.get('x_axis') or 'time'
                self._print(
                    'Plot `{}` of [{}] against `{}` will be registered when this '
                    'model is being trained on SINGA-Auto'
                    .format(title, ', '.join(str(m) for m in metrics), x_axis))

            elif log_type == LogType.METRICS:
                # 'type' and 'time' are bookkeeping fields, not metrics; the
                # print prefix already carries a timestamp.
                metrics_log = ', '.join(
                    '{}={}'.format(metric, value)
                    for (metric, value) in log_dict.items()
                    if metric not in ('type', 'time'))
                self._print('Metric(s) logged: {}'.format(metrics_log))

            elif log_type == LogType.MESSAGE:
                self._print(log_dict.get('message'))

            else:
                self._print(log_line)
        except Exception:
            # Per the logging.Handler contract, failures inside emit() are
            # reported via handleError() rather than raised into the caller.
            self.handleError(record)

    def _print(self, message):
        print('[{}][{}]'.format(__name__, str(datetime.now())), message)
