from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Sequence

import numpy as np


@dataclass(eq=False)
class Branch:
    """
    The most basic element describing the relation between two potentials/terminals.

    Please use the more specific implementations `VoltageBranch` and `CurrentBranch` instead. Subclassing
    this base `Branch` class directly will have no meaning to the simulation algorithm that will simply ignore
    everything else.

    The dataclass is declared with ``eq=False``. No field-based ``__eq__`` is generated, so equality and hashing
    stay identity-based. Two branches with identical recorded values are still distinct objects, and a branch
    can be used as a dict key or set member.

    name:
        An optional name to set for the component. Defaults to `None`; can be set after initialization.
    v:
        The list of voltages (unit: Volts) calculated by simulation. You can access it on each branch individually.
    i:
        The list of currents (unit: Amperes) calculated by simulation. You can access it on each branch individually.
    """
    name: Optional[str] = field(default=None, init=False)
    # default_factory gives every instance its own fresh list instead of one shared mutable default.
    v: List[float] = field(default_factory=list, init=False, repr=False)
    i: List[float] = field(default_factory=list, init=False, repr=False)

    @property
    def p(self) -> List[float]:
        """
        :return:
            The branch's power dissipation (unit: Watts), computed element-wise as ``v * i``. A new list is
            generated from `self.v` and `self.i` on every access, so this might be an expensive operation.
            If `v` and `i` differ in length, the result is truncated to the shorter of the two (``zip`` semantics).
        """
        return [v * i for v, i in zip(self.v, self.i)]


@dataclass(eq=False)
class CurrentBranch(Branch):
    """
    A branch whose current is expressed as a function of voltages / currents and the time step.

    Subclasses must implement `get_current` and `get_jacobian`. The voltage is supplied by the simulation
    through `update`.
    """

    def get_current(self, v_i: np.ndarray, dt: float) -> float:
        """
        Calculate the current depending on given voltages / currents and time step.

        :param v_i:
            Voltages (measured in Volts) and currents (measured in Ampere). A vector filled with the potentials and
            branch currents as specified in the edges returned by `connect` (in order). See also `Element.connect`.
        :param dt:
            Time step (measured in seconds).
        :return:
            The current (measured in Ampere).
        :raises NotImplementedError:
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    def get_jacobian(self, v_i: np.ndarray, dt: float) -> Sequence[float]:
        """
        Returns the Jacobian matrix (i.e. the "gradient matrix") of the current function.

        For the numerical algorithm to converge quickly and accurately, electrical circuit problems must
        supply the derivatives of each current / voltage function.

        This function specifically is supposed to return an array of derivatives instead of a matrix. The
        derivatives have to be returned in the same order as the coupled branches specified by the holding
        component.

        See also https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant

        :param v_i:
            Voltages (measured in Volts) and currents (measured in Ampere). A vector filled with the potentials and
            branch currents as specified in the edges returned by `connect` (in order). See also `Element.connect`.
        :param dt:
            Time step (measured in seconds). Note that dt / t (time) is not a variable to differentiate for in the
            Jacobian.
        :return:
            The partial derivatives of the current with respect to each entry of `v_i`, in the same order.
            Derivatives with respect to voltages are measured in Ampere per Volt. Derivatives with respect to
            currents are dimensionless.
        :raises NotImplementedError:
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    def update(self, v: float, coupled_v_i: np.ndarray, dt: float) -> None:
        """
        Updates the state of the branch from the given voltage and time step.

        This method is called after each simulation cycle in the `Circuit.simulate` function.
        If the simulation algorithm was able to determine a proper voltage and current,
        this method is called to set those calculated values as the new "truth" for the
        next time step.

        By default, this appends the given voltage to `v`. It also appends the current computed via
        `get_current(coupled_v_i, dt)` to `i`. Be sure when you override this function to also
        call the base function via `super().update(v, coupled_v_i, dt)`.

        :param v:
            New voltage (measured in Volts).
        :param coupled_v_i:
            The voltages and currents of all other coupled edges (measured in Volts and Amperes respectively).
            Passed unchanged to `get_current` as its `v_i` argument.
        :param dt:
            New time step (measured in seconds).
        """
        self.v.append(v)
        self.i.append(self.get_current(coupled_v_i, dt))


@dataclass(eq=False)
class VoltageBranch(Branch):
    """
    A branch whose voltage is expressed as a function of voltages / currents and the time step.

    Subclasses must implement `get_voltage` and `get_jacobian`. The current is supplied by the simulation
    through `update`.
    """

    def get_voltage(self, v_i: np.ndarray, dt: float) -> float:
        """
        Calculate the voltage depending on given voltages / currents and time step.

        :param v_i:
            Voltages (measured in Volts) and currents (measured in Ampere). A vector filled with the potentials and
            branch currents as specified in the edges returned by `connect` (in order). See also `Element.connect`.
        :param dt:
            Time step (measured in seconds).
        :return:
            The voltage (measured in Volts).
        :raises NotImplementedError:
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    def get_jacobian(self, v_i: np.ndarray, dt: float) -> Sequence[float]:
        """
        Returns the Jacobian matrix (i.e. the "gradient matrix") of the voltage function.

        For the numerical algorithm to converge quickly and accurately, electrical circuit problems must
        supply the derivatives of each current / voltage function.

        This function specifically is supposed to return an array of derivatives instead of a matrix. The
        derivatives have to be returned in the same order as the coupled branches specified by the holding
        component.

        See also https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant

        :param v_i:
            Voltages (measured in Volts) and currents (measured in Ampere). A vector filled with the potentials and
            branch currents as specified in the edges returned by `connect` (in order). See also `Element.connect`.
        :param dt:
            Time step (measured in seconds). Note that dt / t (time) is not a variable to differentiate for in the
            Jacobian.
        :return:
            The partial derivatives of the voltage with respect to each entry of `v_i`, in the same order.
            Derivatives with respect to currents are measured in Volt per Ampere. Derivatives with respect to
            voltages are dimensionless.
        :raises NotImplementedError:
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    def update(self, i: float, coupled_v_i: np.ndarray, dt: float) -> None:
        """
        Updates the state of the branch from the given current and time step.

        This method is called after each simulation cycle in the `Circuit.simulate` function.
        If the simulation algorithm was able to determine a proper voltage and current,
        this method is called to set those calculated values as the new "truth" for the
        next time step.

        By default, this appends the voltage computed via `get_voltage(coupled_v_i, dt)` to `v`.
        It also appends the given current to `i`. Be sure when you override this function to also
        call the base function via `super().update(i, coupled_v_i, dt)`.

        :param i:
            New current (measured in Ampere).
        :param coupled_v_i:
            The voltages and currents of all other coupled edges (measured in Volts and Amperes respectively).
            Passed unchanged to `get_voltage` as its `v_i` argument.
        :param dt:
            New time step (measured in seconds).
        """
        self.v.append(self.get_voltage(coupled_v_i, dt))
        self.i.append(i)
