#
#  Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an
# express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#

import functools
import logging
import socket
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, List, Optional, Type

import torch
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.plugins.io import AsyncCheckpointIO
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities.types import STEP_OUTPUT
from tzlocal import get_localzone

from training_telemetry.config import TelemetryConfig
from training_telemetry.events import Event, EventName
from training_telemetry.metrics import ApplicationMetrics, CheckpointMetrics, IterationMetrics
from training_telemetry.provider import Provider
from training_telemetry.recorder import Recorder
from training_telemetry.spans import Span, SpanName
from training_telemetry.torch.utils import end_monitoring_flops, get_rank, get_world_size, start_monitoring_flops
from training_telemetry.utils import get_current_time, get_logger
from training_telemetry.verbosity import Verbosity

# Capture the start time as soon as this module is imported: that is the closest available approximation of when the
# application started. Making the caller pass it to on_app_start() was not an option, because the signature of
# on_app_start() in the Nemo BaseCallback is generic. Keep this assignment first so nothing delays the capture.
_START_TIME = get_current_time()
# Module-level logger, used for warnings about unexpected span state (e.g. a span started twice or never started).
_logger = get_logger(__name__)


@dataclass
class SpanSummary:
    """
    Running count and accumulated duration of the completed instances of a single span.

    Attributes:
        num: How many times the span has been stopped since the last reset.
        elapsed: Sum of the elapsed times passed to add() since the last reset.
    """

    num: int = 0
    elapsed: float = 0

    def add(self, elapsed: float) -> None:
        """Account for one more completed span that lasted ``elapsed``."""
        self.num, self.elapsed = self.num + 1, self.elapsed + elapsed

    def reset(self) -> None:
        """Forget every span added so far."""
        self.num, self.elapsed = 0, 0


class IterationState:
    """
    Training statistics accumulated over one logging interval, plus flags that outlive a single interval.

    reset() clears only the per-interval counters (samples, FLOPs and span timings);
    ``is_async_checkpoint`` and ``last_logged_step`` are not touched by it.
    """

    def __init__(self) -> None:
        # Per-interval counters, cleared by reset().
        self.num_samples: int = 0
        self.num_flops: float = 0
        # Selects which checkpoint-save span name is used.
        self.is_async_checkpoint: bool = False
        # Per-span timing summaries; missing entries are created on first access.
        self.span_summaries: dict[SpanName, SpanSummary] = defaultdict(SpanSummary)
        # Global step of the most recent iteration-metrics event; -1 means nothing has been logged yet.
        self.last_logged_step: int = -1

    def reset(self) -> None:
        """Zero the per-interval counters and every span summary collected so far."""
        self.num_samples = 0
        self.num_flops = 0
        for summary in self.span_summaries.values():
            summary.reset()


class TelemetryCallback(Callback):
    """
    Records training telemetry metrics during the training loop by using PyTorch Lightning's callback system.
    Also records application metrics before and after the training loop, but only if running with Nemo version >= 25.09 or
    if the user application calls these additional callbacks manually. Examples of non PTL functions are:
    - on_app_start() and on_app_end()
    - on_dataloader_init_start() and on_dataloader_init_end()
    - on_optimizer_init_start() and on_optimizer_init_end()
    - on_load_checkpoint_start() and on_load_checkpoint_end()

    Checkpoint save is intercepted by wrapping the trainer.save_checkpoint method to add telemetry callbacks before and after
    the checkpoint save operation.

    Validation batch timings are tracked separately from training statistics, so a validation run in the middle of a
    logging interval does not discard the training statistics accumulated so far in that interval.

    Args:
        config: The telemetry configuration. This determines how the telemetry is reported, e.g. using python logging or a
         JSON file, or OpenTelemetry. It also determines the application properties, e.g. the job name, job ID, and environment
         where the application is running.
        batch_flops: The floating point operations per batch of the training data for the forward and backward passes.
          If they are known from first principles, they can be provided using this parameter, otherwise the callback
          will try to calculate them with pytorch utilities but this is not guaranteed to work every time. If the batch flops
          change over time, then this parameter should not be passed, or the callback property batch_flops should be updated manually.
        async_io_checkpoint_classes: The classes of checkpoint IO that are asynchronous.
         Pass a list of classes if using asynchronous checkpointing, so that the correct
         checkpoint strategy can be reported.

    See Also:
        https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
    """

    def __init__(
        self,
        config: TelemetryConfig,
        batch_flops: float = 0,
        async_io_checkpoint_classes: List[Type[Any]] | None = None,
    ):
        super().__init__()
        self.config = config
        self.log_interval = config.application.log_interval
        # 0 means "not known yet": FLOPs are then measured during the first training batch.
        # -1 means that measurement was attempted and produced nothing, so it is not retried.
        self.batch_flops = batch_flops
        self.batch_flops_monitor: Optional[Any] = None
        self.async_io_checkpoint_classes = async_io_checkpoint_classes or []
        # Training statistics accumulated between two TRAINING_ITERATIONS events.
        self.state = IterationState()
        # Validation batch timings. Kept apart from self.state so that a validation run in the middle of a
        # logging interval neither wipes nor pollutes the training statistics.
        self._validation_iterations = SpanSummary()
        self._recorder: Optional[Recorder] = (
            None  # Nemo needs this to be pickleable and it isn't, so need to postpone creating it
        )
        # Currently open spans keyed by name; at most one span per name is open at a time.
        self._spans: dict[SpanName, Span] = {}

    def _start_span(
        self,
        span_name: Any,
        metrics: Any = None,
        verbosity: Any = None,
        color: Any = None,
        start_time: Optional[float] = None,
    ) -> Span:
        """
        Start a span with the given name and optional metrics, color, verbosity and start time.

        If a span with the same name is already open, it is stopped first (with a warning). Verbosity defaults to
        Verbosity.INFO and the start time defaults to the current time.
        """
        if span_name in self._spans:
            _logger.warning(f"Span {span_name} already started, stopping it but this is unexpected")
            self.recorder.stop(self._spans[span_name])
            del self._spans[span_name]
        if verbosity is None:
            verbosity = Verbosity.INFO
        if start_time is None:
            start_time = get_current_time()
        span = self.recorder.start(
            name=span_name,
            color=color,
            start_time=start_time,
            verbosity=verbosity,
            metrics=metrics,
        )
        self._spans[span_name] = span
        return span

    def _stop_span(self, span_name: Any, metrics: Any = None) -> float:
        """
        Stop an open span, optionally attaching metrics to it first, and return its elapsed time.

        Returns 0.0 (and logs a warning) if no span with that name is open.
        """
        span = self._spans.get(span_name)
        if span is None:
            _logger.warning(f"Span {span_name} was not started, this is unexpected")
            return 0.0
        if metrics is not None:
            span.add_metrics(metrics)

        self.recorder.stop(span)
        del self._spans[span_name]
        return span.duration.elapsed

    def _stop_tracked_span(self, span_name: SpanName) -> None:
        """
        Stop ``span_name`` if it is currently open and add its elapsed time to the training span summaries.

        Does nothing when the span is not open, e.g. when the hook that would have started it was never called.
        """
        if span_name in self._spans:
            elapsed = self._stop_span(span_name)
            self.state.span_summaries[span_name].add(elapsed)

    @staticmethod
    def _average_elapsed(summary: SpanSummary) -> float:
        """Return the mean elapsed time per stop recorded in ``summary``, or 0 if it recorded none."""
        return summary.elapsed / summary.num if summary.num > 0 else 0

    def _checkpoint_span_name(self) -> SpanName:
        """Return the checkpoint-save span name matching the checkpoint IO kind detected in setup()."""
        if self.state.is_async_checkpoint:
            return SpanName.CHECKPOINT_SAVE_ASYNC
        return SpanName.CHECKPOINT_SAVE_SYNC

    def _add_checkpoint_callbacks(self, trainer: Trainer) -> None:
        """
        Wrap trainer.save_checkpoint so that every checkpoint save is recorded as a span.

        The original method is stored on the wrapper as ``orig_method`` so that teardown() can restore it. If the
        trainer's method is already such a wrapper (e.g. setup() ran again before teardown() restored it), the stored
        original is wrapped instead. Stacking a second wrapper would start the same checkpoint span twice per save.
        """
        current = trainer.save_checkpoint
        method = getattr(current, "orig_method", current)

        @functools.wraps(method)
        def wrapper(filepath: _PATH, *args: Any, **kwargs: Any) -> None:
            self.on_start_save_checkpoint(trainer, filepath)
            exc: Optional[Exception] = None
            try:
                method(filepath, *args, **kwargs)
            except Exception as e:
                exc = e
                raise
            finally:
                # Always close the span, whether the save succeeded or raised.
                self.on_end_save_checkpoint(trainer, exc)

        wrapper.orig_method = method  # type: ignore[attr-defined]
        trainer.save_checkpoint = wrapper  # type: ignore[method-assign]

    def _start_monitoring_flops(self) -> None:
        """
        Start monitoring FLOPS using pytorch utilities.
        This is called at the start of the first batch, unless the FLOPS have been set manually by the user.
        """
        assert self.batch_flops_monitor is None
        self.batch_flops_monitor = start_monitoring_flops()

    def _end_monitoring_flops(self) -> None:
        """
        End monitoring FLOPS using pytorch utilities. This is called at the beginning of the optimizer step of the first training batch,
        and at the end of the first training batch, in case the optimizer step callback is not called.
        """
        assert self.batch_flops_monitor is not None
        self.batch_flops = end_monitoring_flops(self.batch_flops_monitor)
        self.batch_flops_monitor = None
        if self.batch_flops == 0:
            self.batch_flops = -1  # This ensures we don't try again on the second batch if we failed to get the FLOPs

    @property
    def recorder(self) -> Recorder:
        """The telemetry recorder, created lazily on first use so that the callback stays pickleable."""
        if self._recorder is None:
            self._recorder = Provider.set_provider(self.config).recorder
        return self._recorder

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """For the "fit" stage: wrap checkpoint saving, detect async checkpoint IO and open the main-function span."""
        if stage != "fit":
            return

        if trainer.save_checkpoint is not None:
            self._add_checkpoint_callbacks(trainer)

        async_io_classes = (AsyncCheckpointIO, *self.async_io_checkpoint_classes)
        self.state.is_async_checkpoint = isinstance(trainer.strategy.checkpoint_io, async_io_classes)

        self.on_app_start()
        self.record_app_metrics(trainer)

    def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """For the "fit" stage: close the main-function span and restore the original trainer.save_checkpoint."""
        if stage != "fit":
            return

        self.on_app_end()

        # Restore the save checkpoint method
        if trainer.save_checkpoint is not None and hasattr(trainer.save_checkpoint, "orig_method"):
            trainer.save_checkpoint = trainer.save_checkpoint.orig_method  # type: ignore[method-assign]

    def record_app_metrics(self, trainer: Trainer) -> None:
        """Attach application-level metrics (rank, world size, host, timezone, ...) to the main-function span."""
        span = self._spans.get(SpanName.MAIN_FUNCTION)
        if span is None:
            _logger.warning("Main function span not found, this is unexpected")
            return

        # Lightning reports float("inf") for unbounded training (max_epochs=-1 and max_steps=-1), and
        # int(float("inf")) raises OverflowError. Report -1 instead, which is also what Lightning yields
        # (via max_steps=-1) for an unsized iterable dataset without a step limit.
        estimated_steps = trainer.estimated_stepping_batches
        total_iterations = -1 if estimated_steps == float("inf") else int(estimated_steps)

        metrics = ApplicationMetrics.create(
            rank=get_rank(),
            world_size=get_world_size(),
            node_name=socket.gethostname(),
            timezone=str(get_localzone()),
            total_iterations=total_iterations,
            checkpoint_enabled=bool(trainer.checkpoint_callbacks),
        )
        self.recorder.event(Event.create(EventName.SPAN_ATTRIBUTES, metrics), span)

    # App lifecycle (from Nemo BaseCallback version >= 25.09), or otherwise called from setup() and teardown()
    def on_app_start(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """Called when the application starts; opens the main-function span backdated to module import time."""
        if SpanName.MAIN_FUNCTION in self._spans:
            return  # This callback may be called multiple times

        # _START_TIME is only read here, so no `global` declaration is needed.
        self._start_span(SpanName.MAIN_FUNCTION, start_time=_START_TIME)

    def on_app_end(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """Called when the application ends; closes the main-function span if it is open."""
        if SpanName.MAIN_FUNCTION not in self._spans:
            return  # This callback may be called multiple times

        self._stop_span(SpanName.MAIN_FUNCTION)

    # For checkpoint save callbacks, we add our own wrapper on trainer.save_checkpoint, see setup() and teardown() functions.
    # Nemo callbacks are also available, but it looks like they may be called in child processes and do not pass sufficient
    # information to the callback, so we avoid them. It's important that the callback names do not clash, so we moved the "start"
    # and "end" at the beginning of the function names.
    def on_start_save_checkpoint(self, trainer: Trainer, filepath: _PATH) -> None:
        """Open the sync or async checkpoint-save span, depending on the detected checkpoint IO."""
        self._start_span(self._checkpoint_span_name())

    def on_end_save_checkpoint(
        self,
        trainer: Trainer,
        exception: Optional[Exception],
    ) -> None:
        """
        Close the checkpoint-save span and attach the trainer's current global step to it.

        ``exception`` is the error raised by the save (None on success); it is currently not recorded here.
        """
        metrics = CheckpointMetrics.create(
            current_iteration=trainer.global_step,
        )
        self._stop_span(self._checkpoint_span_name(), metrics)

    # Dataloader lifecycle (from Nemo BaseCallback version >= 25.09) or called manually by the user
    def on_dataloader_init_start(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when dataloader initialization starts.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._start_span(SpanName.DATA_LOADER_INIT)

    def on_dataloader_init_end(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when dataloader initialization ends.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._stop_span(SpanName.DATA_LOADER_INIT)

    # Optimizer lifecycle (from Nemo BaseCallback version >= 25.09) or called manually by the user
    def on_optimizer_init_start(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when optimizer initialization starts.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._start_span(SpanName.OPTIMIZER_INIT)

    def on_optimizer_init_end(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when optimizer initialization ends.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._stop_span(SpanName.OPTIMIZER_INIT)

    # Checkpoint load lifecycle (from Nemo BaseCallback version >= 25.09) or called manually by the user
    def on_load_checkpoint_start(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when checkpoint loading starts.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._start_span(SpanName.CHECKPOINT_LOAD)

    def on_load_checkpoint_end(self, *args, **kwargs) -> None:  # type: ignore [no-untyped-def]
        """
        Called when checkpoint loading ends.
        This is not called by PyTorch Lightning, but is called by Nemo version >= 25.09.
        """
        self._stop_span(SpanName.CHECKPOINT_LOAD)

    # Remaining functions are from PyTorch Lightning
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Open the training-loop span and start a fresh logging interval."""
        self._start_span(SpanName.TRAINING_LOOP)
        self.state.reset()

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Flush the metrics of the last (possibly partial) logging interval and close the training-loop span."""
        if trainer.global_step != self.state.last_logged_step:
            self._log_iteration_metrics(trainer)
        self._stop_span(SpanName.TRAINING_LOOP)

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        """Open the iteration and forward spans; on the first batch, start measuring FLOPs if they are unknown."""
        self._start_span(SpanName.ITERATION, verbosity=Verbosity.PROFILING)
        self._start_span(SpanName.MODEL_FORWARD, verbosity=Verbosity.PROFILING)

        # The monitor check guards the assertion in _start_monitoring_flops against a monitor left
        # running by an interrupted batch.
        if self.batch_flops == 0 and self.batch_flops_monitor is None:
            self._start_monitoring_flops()

    # Note that this won't be called by Nemo when using the Megatron Core models and training strategy. See function docstring for more details.
    def on_before_backward(
        self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor, optimizer_idx: int = 0
    ) -> None:
        """
        The before and after backward callbacks are called before and after the backward pass, respectively.
        Note that when using Megatron Core models and Megatron training strategies in Nemo, these will never be called. This is
        because Megatron Core does not invoke model.backward(), it only invokes model.forward() and then uses
        a loss function to compute the backward pass, which cannot be tracked with PTL callbacks. So when
        running Megatron models in Nemo, the forward time actually includes the backward pass. This is the same
        for MLM, where forward-backward are timed together in MLM log messages (when enabled).
        """
        self._stop_tracked_span(SpanName.MODEL_FORWARD)
        self._start_span(SpanName.MODEL_BACKWARD, verbosity=Verbosity.PROFILING)

    def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Close the backward span, if open, and accumulate its duration."""
        self._stop_tracked_span(SpanName.MODEL_BACKWARD)

    def on_before_optimizer_step(
        self, trainer: Trainer, pl_module: LightningModule, optimizer: torch.optim.Optimizer
    ) -> None:
        """Close the forward span if still open, finish FLOPs measurement if active, and open the optimizer span."""
        # MODEL_FORWARD is still open with Megatron Core models and training strategies in Nemo, which do not invoke model.backward()
        self._stop_tracked_span(SpanName.MODEL_FORWARD)

        if self.batch_flops_monitor is not None:
            self._end_monitoring_flops()

        self._start_span(SpanName.OPTIMIZER_UPDATE, verbosity=Verbosity.PROFILING)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """
        Close the per-batch spans, accumulate samples and FLOPs, and emit iteration metrics at each logging interval.
        """
        # Stop monitoring FLOPS if still running (on_before_optimizer_step may not have been called)
        if self.batch_flops_monitor is not None:
            self._end_monitoring_flops()

        # MODEL_FORWARD is still open with Megatron Core models and training strategies in Nemo that do not invoke
        # model.backward(); OPTIMIZER_UPDATE is open only if on_before_optimizer_step ran for this batch.
        self._stop_tracked_span(SpanName.MODEL_FORWARD)
        self._stop_tracked_span(SpanName.OPTIMIZER_UPDATE)

        # ITERATION is always opened in on_train_batch_start, so it is stopped unconditionally.
        iteration_elapsed = self._stop_span(SpanName.ITERATION)
        self.state.span_summaries[SpanName.ITERATION].add(iteration_elapsed)

        self.state.num_samples += extract_batch_size(batch)
        if self.batch_flops > 0:
            self.state.num_flops += self.batch_flops

        # Send an event only when global_step has changed and is a multiple of log_interval.
        # This prevents logging multiple times per optimizer step when using gradient accumulation.
        step = trainer.global_step
        if step != self.state.last_logged_step and step % self.log_interval == 0:
            self._log_iteration_metrics(trainer)
            self.state.last_logged_step = step

    def _log_iteration_metrics(self, trainer: Trainer) -> None:
        """
        Emit a TRAINING_ITERATIONS event averaging the training statistics accumulated since the last event,
        then start a new interval. Does nothing if no training iteration completed in the interval.
        """
        summaries = self.state.span_summaries
        iterations = summaries[SpanName.ITERATION].num
        if iterations == 0:
            return

        avg_iteration_time = summaries[SpanName.ITERATION].elapsed / iterations
        avg_forward_time = self._average_elapsed(summaries[SpanName.MODEL_FORWARD])
        avg_backward_time = self._average_elapsed(summaries[SpanName.MODEL_BACKWARD])
        avg_optimizer_update_time = self._average_elapsed(summaries[SpanName.OPTIMIZER_UPDATE])

        if avg_forward_time > 0 or avg_backward_time > 0:
            # For Megatron models in Nemo, avg_backward_time will be 0 but avg_forward_time will include both.
            # num_flops is summed over every batch of the interval, so convert it to FLOPs per batch before
            # dividing by the per-batch (average) forward+backward time.
            avg_flops_per_iteration = self.state.num_flops / iterations
            flops_per_second = avg_flops_per_iteration / (avg_forward_time + avg_backward_time)
            tflops_per_second = flops_per_second / 1e12
        else:
            tflops_per_second = 0

        metrics = IterationMetrics.create(
            current_iteration=trainer.global_step,
            # Number of iterations actually averaged: can be fewer than log_interval for the final, partial
            # interval flushed by on_train_end.
            num_iterations=iterations,
            interval=self.log_interval,
            batch_size=int(self.state.num_samples / iterations),
            average_iteration_time=avg_iteration_time,
            tflops=tflops_per_second,
            average_forward_time=avg_forward_time,
            average_optimizer_update_time=avg_optimizer_update_time,
        )

        if avg_backward_time > 0:
            metrics.add_metric("average_backward_time", avg_backward_time)

        self.recorder.event(Event.create(EventName.TRAINING_ITERATIONS, metrics))
        self.state.reset()

    def on_exception(self, trainer: Trainer, pl_module: LightningModule, exception: BaseException) -> None:
        """
        Report the exception to the recorder. Exception instances are passed through as-is; other BaseExceptions
        (e.g. KeyboardInterrupt) are reported as a message that includes their formatted traceback.
        """
        if isinstance(exception, Exception):
            self.recorder.error(str(exception), exception)
        else:
            message = f"{type(exception).__name__}: {str(exception)}\n" + "".join(traceback.format_exception(exception))
            self.recorder.error(message)

    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Open the validation-loop span and clear the validation timings (training statistics are left intact)."""
        self._start_span(SpanName.VALIDATION_LOOP)
        self._validation_iterations.reset()

    def on_validation_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Open the per-batch iteration span, except during sanity checking."""
        # Skip because it may not run all callbacks during sanity checking
        if trainer.sanity_checking:
            return

        self._start_span(SpanName.ITERATION, verbosity=Verbosity.PROFILING)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Close the per-batch iteration span and accumulate its duration into the validation timings."""
        # Skip because it may not run all callbacks during sanity checking
        if trainer.sanity_checking:
            return

        self._validation_iterations.add(self._stop_span(SpanName.ITERATION))

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Close the validation-loop span, attaching average batch time when at least one batch was timed."""
        summary = self._validation_iterations
        if summary.num == 0:
            self._stop_span(SpanName.VALIDATION_LOOP)
            return

        metrics = IterationMetrics.create(
            current_iteration=trainer.global_step,
            num_iterations=summary.num,
            average_iteration_time=summary.elapsed / summary.num,
        )
        self._stop_span(SpanName.VALIDATION_LOOP, metrics)
        summary.reset()
