################################################################################
# © Copyright 2021-2022 Zapata Computing Inc.
################################################################################
import operator
from functools import reduce, singledispatch
from itertools import groupby
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import sympy

from . import _gates, _operations


def _circuit_size_by_operations(operations):
    return (
        0
        if not operations
        else max(
            qubit_index
            for operation in operations
            for qubit_index in operation.qubit_indices
        )
        + 1
    )


def _operation_uses_custom_gate(operation):
    """Tell whether ``operation`` applies a gate built from a custom matrix factory.

    True only when the gate is a ``MatrixFactoryGate`` whose ``matrix_factory``
    is a ``CustomGateMatrixFactory``.
    """
    gate = operation.gate
    if not isinstance(gate, _gates.MatrixFactoryGate):
        return False
    return isinstance(gate.matrix_factory, _gates.CustomGateMatrixFactory)


class Circuit:
    """orquestra representation of a quantum circuit.

    See `help(orquestra.quantum.circuits)` for usage guide.
    """

    def __init__(
        self,
        operations: Optional[Iterable[_operations.Operation]] = None,
        n_qubits: Optional[int] = None,
    ):
        """Build a circuit from a sequence of operations.

        Args:
            operations: operations to apply, in order of application. The
                iterable is copied into a new list; ``None`` yields an empty
                circuit.
            n_qubits: number of qubits in the circuit. If falsy (``None`` or 0),
                it is inferred as one more than the largest qubit index used by
                ``operations``, or 0 when no qubits are used. Integer-valued
                inputs of other types (e.g. ``3.0``) are normalized to ``int``.

        Raises:
            ValueError: if ``n_qubits`` is not integer-valued or is negative
                (``int()`` itself may also raise for values it cannot convert).
        """
        self._operations = list(operations) if operations is not None else []

        if n_qubits:
            cast_n_qubits = int(n_qubits)

            if n_qubits != cast_n_qubits:
                raise ValueError("Non-integer value passed.")

            if cast_n_qubits <= 0:
                raise ValueError("Non-positive value passed.")

            # Store the normalized value so `n_qubits` is always a plain `int`
            # rather than e.g. a float like 3.0 or a numpy integer.
            self._n_qubits = cast_n_qubits
        else:
            self._n_qubits = _circuit_size_by_operations(self._operations)

    @property
    def operations(self):
        """Sequence of quantum gates to apply to qubits in this circuit."""
        return self._operations

    @property
    def n_qubits(self) -> int:
        """Number of qubits in this circuit.
        Not every qubit has to be used by a gate.
        """
        return self._n_qubits

    @property
    def free_symbols(self) -> List[sympy.Symbol]:
        """List of all the sympy symbols used as params of gates in the circuit.

        Each symbol appears once. Symbols are ordered by their first appearance
        in `self._operations`.
        """
        seen_symbols = set()
        symbols_sequence = []
        for operation in self._operations:
            for symbol in operation.free_symbols:
                if symbol not in seen_symbols:
                    seen_symbols.add(symbol)
                    symbols_sequence.append(symbol)

        return symbols_sequence

    def __eq__(self, other: object):
        """Circuits are equal when qubit counts and operation sequences match."""
        if not isinstance(other, type(self)):
            return NotImplemented

        if self.n_qubits != other.n_qubits:
            return False

        return list(self.operations) == list(other.operations)

    def __add__(self, other: Union["Circuit", _gates.GateOperation]):
        """Return a new circuit with `other` appended after this circuit."""
        return _append_to_circuit(other, self)

    def collect_custom_gate_definitions(self) -> Iterable[_gates.CustomGateDefinition]:
        """Return the distinct custom gate definitions used by this circuit.

        Only operations whose gate is built from a ``CustomGateMatrixFactory``
        are considered. Definitions are deduplicated by ``gate_name`` and
        returned as a list sorted by that name.

        Raises:
            ValueError: if two unequal definitions share the same ``gate_name``.
        """
        custom_gate_definitions = (
            operation.gate.matrix_factory.gate_definition
            for operation in self.operations
            if _operation_uses_custom_gate(operation)
        )
        unique_operation_dict = {}
        for gate_def in custom_gate_definitions:
            if gate_def.gate_name not in unique_operation_dict:
                unique_operation_dict[gate_def.gate_name] = gate_def
            elif unique_operation_dict[gate_def.gate_name] != gate_def:
                raise ValueError(
                    "Different gate definitions with the same name exist: "
                    f"{gate_def.gate_name}."
                )
        return sorted(
            unique_operation_dict.values(), key=operator.attrgetter("gate_name")
        )

    def to_unitary(self) -> Union[np.ndarray, sympy.Matrix]:
        """Create a unitary matrix describing Circuit's action.

        For performance reasons, this method will construct numpy matrix if circuit does
        not have free parameters, and a sympy matrix otherwise.

        An empty circuit yields the numpy identity matrix of size
        ``2**n_qubits``.
        """
        if not self.operations:
            # `reduce` without an initial value fails on an empty sequence. An
            # empty circuit acts as the identity on its 2**n_qubits-dim space.
            return np.eye(2**self.n_qubits, dtype=complex)
        # The `reversed` iterator reflects the fact the matrices are multiplied
        # when composing linear operations (i.e. first operation is the rightmost).
        lifted_matrices = [
            op.lifted_matrix(self.n_qubits) for op in reversed(self.operations)
        ]
        return reduce(operator.matmul, lifted_matrices)

    def bind(self, symbols_map: Dict[sympy.Symbol, Any]):
        """Create a copy of the current circuit with the parameters of each gate bound
        to the values provided in the input symbols map.

        Args:
            symbols_map: A map of the symbols/gate parameters to new values

        Returns:
            A new circuit of the same type and with the same `n_qubits`.
        """
        return type(self)(
            operations=[op.bind(symbols_map) for op in self.operations],
            n_qubits=self.n_qubits,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"(operations=[{', '.join(map(str, self.operations))}], "
            f"n_qubits={self.n_qubits})"
        )

    def inverse(self) -> "Circuit":
        """Return the inverse of this circuit.

        The result applies the `dagger` of each gate to the same qubits, in
        reverse order. Running this circuit followed by the result therefore
        acts as the identity. The result has the same type and `n_qubits`.

        Raises:
            AttributeError: if some operation's gate has no `dagger`.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to
        # avoid changing the exception type callers may rely on.
        assert all(isinstance(op, _gates.GateOperation) for op in self.operations)
        # Only the `dagger` lookups are guarded, so that unrelated
        # AttributeErrors (e.g. from the constructor) are not mislabeled.
        try:
            inverted_operations = [
                op.gate.dagger(*op.qubit_indices) for op in reversed(self.operations)
            ]
        except AttributeError as e:
            raise AttributeError(
                "Inverse is not implemented for this circuit type,"
                " since there are operators in it without the `dagger` method."
            ) from e
        return type(self)(operations=inverted_operations, n_qubits=self.n_qubits)


@singledispatch
def _append_to_circuit(other, circuit: Circuit):
    """Return a new circuit with ``other`` appended after ``circuit``.

    This is the fallback of a ``functools.singledispatch`` generic function
    that dispatches on the type of ``other``. Supported types are handled by
    implementations registered with ``_append_to_circuit.register``.

    Raises:
        NotImplementedError: always. Reaching the fallback means no
            implementation is registered for ``type(other)``.
    """
    raise NotImplementedError(
        f"Cannot append object of type {type(other).__name__} "
        f"to {type(circuit).__name__}."
    )


@_append_to_circuit.register
def _append_operation(other: _gates.GateOperation, circuit: Circuit):
    """Append a single gate operation after the operations of ``circuit``.

    The result is widened if the operation touches a qubit index beyond
    ``circuit.n_qubits``. It has the same type as ``circuit``; ``circuit``
    itself is left unmodified.
    """
    required_qubits = max(other.qubit_indices) + 1
    extended_operations = list(circuit.operations) + [other]
    circuit_cls = type(circuit)
    return circuit_cls(
        operations=extended_operations,
        n_qubits=max(circuit.n_qubits, required_qubits),
    )


@_append_to_circuit.register
def _append_circuit(other: Circuit, circuit: Circuit):
    """Concatenate ``other``'s operations after those of ``circuit``.

    The combined circuit is as wide as the wider of the two inputs and has the
    same type as ``circuit``. Neither input is modified.
    """
    combined_operations = list(circuit.operations)
    combined_operations.extend(other.operations)
    combined_width = max(circuit.n_qubits, other.n_qubits)
    circuit_cls = type(circuit)
    return circuit_cls(operations=combined_operations, n_qubits=combined_width)


def split_circuit(
    circuit: Circuit, predicate: Callable[[_operations.Operation], bool]
) -> Iterable[Tuple[bool, Circuit]]:
    """Split circuit into subcircuits for which predicate on all operations is constant.

    Args:
        circuit: a circuit to be split
        predicate: function assigning boolean value to each operation, its values
          are used for grouping operations belonging to the same subcircuits.
    Returns:
        A lazily-evaluated iterable of tuples of the form (x, subcircuit) s.t.:
        - predicate(operation) == x for every operation in subcircuit.operations
        - for two consecutive tuples (x1, subcircuit1), (x2, subcircuit2)
          x1 != x2 (i.e. consecutive chunks differ in the predicate value),
        - operations in subcircuits follow the same order as in original circuit
        - all subcircuits have the same number of qubits equal to `circuit.n_qubits`
          and the same type as `circuit`.
    """
    # Preserve subclasses, as `Circuit.bind`/`Circuit.inverse` do via `type(self)`.
    circuit_type = type(circuit)
    n_qubits = circuit.n_qubits
    for predicate_value, operations in groupby(circuit.operations, predicate):
        # `groupby` groups are invalidated once the outer iterator advances, so
        # materialize each group instead of relying on the constructor to do it.
        yield predicate_value, circuit_type(list(operations), n_qubits=n_qubits)
