"""
This module provides a utility class for creating and managing matplotlib figures and axes.

The Figures class simplifies the process of generating plots, customizing axes, and saving figures.
It supports various types of plots, including scatter plots, line plots, histograms, and 2D histograms.
Additionally, it includes methods for adding colorbars, setting titles, and managing subplots.
"""
import matplotlib.pyplot as plt
import numpy as np
from typing import Any


class Figures:
    """
    A utility class for creating and managing matplotlib figures and axes.

    This class provides methods for initializing figures, creating subplots,
    and generating various types of plots. It also includes functionality
    for customizing axes, adding colorbars, setting titles, and saving figures.

    Attributes:
        fig: The current matplotlib figure.
        ax: The axes object returned by `plt.subplots` for `fig`: a single
            Axes, or a numpy array of Axes when the figure has sub-figures.
        active_ax: The Axes that the plotting and axis-customisation methods
            act on.
        active_plot: The return value of the most recent plotting call on the
            current figure, or ``None`` if nothing has been plotted yet.
    """

    def __init__(self) -> None:
        """
        Initialize a Figures instance with a default figure and axis.

        The default axis is made active immediately, so plotting methods can be
        used without first calling `initialise_figure`.
        """
        self.fig, self.ax = plt.subplots()
        self.active_ax = self.ax
        self.active_plot = None

    def initialise_figure(self, *args: Any, **kwargs: Any) -> None:
        """
        Initialize a new figure with optional arguments for customization.

        The new figure replaces the current one on this instance. The active
        axis is reset to the new figure's single axis or, when the figure has
        sub-figures, to its first (top-left) axis; use `set_subfigure` to pick
        another one. Any previously recorded plot is forgotten so that
        `colorbar` cannot attach to an artist belonging to the old figure.

        NOTE: the previous figure is not closed; it stays registered with
        pyplot.

        Args:
            *args: Positional arguments passed to `plt.subplots`.
            **kwargs: Keyword arguments passed to `plt.subplots`.
        """
        self.fig, self.ax = plt.subplots(*args, **kwargs)
        if isinstance(self.ax, np.ndarray):
            # Grid of sub-figures: default to the first axis instead of leaving
            # an axis from the previous figure active.
            self.active_ax = self.ax.flat[0]
        else:
            self.active_ax = self.ax
        self.active_plot = None

    def set_subfigure(self, row: int, cols: int) -> None:
        """
        Set the active subplot based on the specified row and column indices.

        When the axes array is one-dimensional (a single row or a single
        column of sub-figures), at most one of `row` and `cols` may be
        non-zero; the non-zero one selects the subplot.

        Args:
            row (int): The row index of the subplot.
            cols (int): The column index of the subplot.

        Raises:
            IndexError: If the figure does not have subplots, if both indices
                are non-zero for a one-dimensional axes array, or if an index
                is out of range.
        """
        if not isinstance(self.ax, np.ndarray):
            raise IndexError("This figure does not have sub-figures")
        if self.ax.ndim == 1:
            # plt.subplots squeezes a single row/column of axes into a 1-D
            # array, so only one of the two indices can be meaningful.
            if row != 0 and cols != 0:
                raise IndexError(
                    f"Figure has a single row/column of sub-figures; "
                    f"cannot select ({row}, {cols})"
                )
            # One index is zero, so the sum is the other (possibly negative) one.
            self.active_ax = self.ax[row + cols]
        else:
            self.active_ax = self.ax[row, cols]

    # Base Plots
    def scatter(self, *args: Any, **kwargs: Any) -> None:
        """
        Create a scatter plot on the active axis.

        This method generates a scatter plot using matplotlib's scatter function,
        with all arguments passed through to the underlying plot method.

        Args:
            *args: Positional arguments passed to `ax.scatter`. Typically:
                - x: array-like, x-coordinates of the data points
                - y: array-like, y-coordinates of the data points
            **kwargs: Keyword arguments passed to `ax.scatter`. Common options:
                - s: marker size
                - c: marker color or color array
                - alpha: transparency level
                - marker: marker style
                - label: legend label

        Example:
            >>> fig = Figures()
            >>> fig.scatter([1, 2, 3], [4, 5, 6], c='red', s=50, alpha=0.7)
        """
        self.active_plot = self.active_ax.scatter(*args, **kwargs)

    def plot(self, *args: Any, **kwargs: Any) -> None:
        """
        Create a line plot on the active axis.

        This method generates a line plot using matplotlib's plot function,
        suitable for continuous data visualization and curve plotting.

        Args:
            *args: Positional arguments passed to `ax.plot`. Typically:
                - x: array-like, x-coordinates of the data points
                - y: array-like, y-coordinates of the data points
                - fmt: string, format string (optional, e.g., 'r-', 'bo')
            **kwargs: Keyword arguments passed to `ax.plot`. Common options:
                - color/c: line color
                - linewidth/lw: line width
                - linestyle/ls: line style ('-', '--', ':', '-.')
                - marker: marker style
                - alpha: transparency level
                - label: legend label

        Example:
            >>> fig = Figures()
            >>> fig.plot([1, 2, 3], [4, 5, 6], 'r-', linewidth=2, label='data')
        """
        self.active_plot = self.active_ax.plot(*args, **kwargs)

    def hist(self, *args: Any, **kwargs: Any) -> None:
        """
        Create a histogram on the active axis.

        Args:
            *args: Positional arguments passed to `ax.hist`.
            **kwargs: Keyword arguments passed to `ax.hist`.
        """
        self.active_plot = self.active_ax.hist(*args, **kwargs)

    def hist2d(self, *args: Any, **kwargs: Any) -> None:
        """
        Create a 2D histogram on the active axis.

        The tuple returned by `ax.hist2d` is stored as the active plot; its
        last element (the image) is what `colorbar` uses.

        Args:
            *args: Positional arguments passed to `ax.hist2d`.
            **kwargs: Keyword arguments passed to `ax.hist2d`.
        """
        self.active_plot = self.active_ax.hist2d(*args, **kwargs)

    # Axes Manipulation
    def set_x_limits(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the x-axis limits of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_xlim`
                (e.g. ``left, right``).
            **kwargs: Keyword arguments passed to `ax.set_xlim`.
        """
        self.active_ax.set_xlim(*args, **kwargs)

    def set_y_limits(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the y-axis limits of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_ylim`
                (e.g. ``bottom, top``).
            **kwargs: Keyword arguments passed to `ax.set_ylim`.
        """
        self.active_ax.set_ylim(*args, **kwargs)

    def set_x_ticks(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the x-axis ticks of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_xticks`.
            **kwargs: Keyword arguments passed to `ax.set_xticks`.
        """
        self.active_ax.set_xticks(*args, **kwargs)

    def set_y_ticks(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the y-axis ticks of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_yticks`.
            **kwargs: Keyword arguments passed to `ax.set_yticks`.
        """
        self.active_ax.set_yticks(*args, **kwargs)

    def set_x_label(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the x-axis label of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_xlabel`.
            **kwargs: Keyword arguments passed to `ax.set_xlabel`.
        """
        self.active_ax.set_xlabel(*args, **kwargs)

    def set_y_label(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the y-axis label of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_ylabel`.
            **kwargs: Keyword arguments passed to `ax.set_ylabel`.
        """
        self.active_ax.set_ylabel(*args, **kwargs)

    # Extras
    def colorbar(self, *args: Any, **kwargs: Any) -> None:
        """
        Add a colorbar for the most recent plot to the figure.

        If the latest plotting call returned a tuple (as `hist2d` does), the
        last element of that tuple is used as the mappable. Otherwise, the
        returned artist itself is used (e.g. the collection from `scatter`).

        Args:
            *args: Positional arguments passed to `fig.colorbar` after the
                mappable.
            **kwargs: Keyword arguments passed to `fig.colorbar`.

        Raises:
            RuntimeError: If nothing has been plotted on the current figure.
        """
        if self.active_plot is None:
            raise RuntimeError("No plot available to attach a colorbar to")
        mappable = self.active_plot
        if isinstance(mappable, tuple):
            # Axes.hist2d returns (h, xedges, yedges, image); the image is last.
            mappable = mappable[-1]
        self.fig.colorbar(mappable, *args, **kwargs)

    def set_title(self, *args: Any, **kwargs: Any) -> None:
        """
        Set the title of the active axis.

        Args:
            *args: Positional arguments passed to `ax.set_title`.
            **kwargs: Keyword arguments passed to `ax.set_title`.
        """
        self.active_ax.set_title(*args, **kwargs)

    def save_to_disk(self, *args: Any, **kwargs: Any) -> None:
        """
        Save the figure to disk.

        Args:
            *args: Positional arguments passed to `fig.savefig`.
            **kwargs: Keyword arguments passed to `fig.savefig`.
        """
        self.fig.savefig(*args, **kwargs)

    def show(self) -> None:
        """
        Display the figure by calling `plt.show()`.

        Note that `plt.show()` displays every open pyplot figure, not only
        this instance's figure.
        """
        plt.show()


if __name__ == "__main__":
    # Example usage of the Figures class: a 2D histogram of two independent
    # standard-normal samples with axis labels, a colorbar and a title,
    # saved to fig.png at 300 dpi.
    F = Figures()
    F.initialise_figure(figsize=(5, 5))

    # For a grid of sub-figures (e.g. F.initialise_figure(2, 2)), select the
    # target axis first with F.set_subfigure(row, col).
    rng = np.random.default_rng()
    F.hist2d(rng.normal(0, 1, 1000), rng.normal(0, 1, 1000))
    F.set_x_limits(left=-2, right=2)
    F.set_x_ticks(np.linspace(-2, 2, 6))
    F.set_y_label('y-axis')
    F.set_x_label('x-axis')
    F.colorbar(label='colorbar')
    F.set_title('Title')
    F.save_to_disk('fig.png', dpi=300)
