from __future__ import annotations

import copy
import numpy as np

from qiskit import QuantumRegister, AncillaRegister, ClassicalRegister, QuantumCircuit
from qiskit.circuit import Bit, Measure
from qiskit.circuit.quantumcircuitdata import QuantumCircuitData
from qiskit.circuit.classical import expr
from qiskit.circuit.library import RXGate, RYGate, RZGate, RXXGate, RYYGate, RZZGate
from qiskit.quantum_info import Statevector, DensityMatrix, Pauli, partial_trace, state_fidelity
from qiskit_addon_utils.slicing import slice_by_depth
from qiskit.transpiler.passes import SolovayKitaev
from qiskit.synthesis import generate_basic_approximations

from qiskit.transpiler import PassManager

from .Transpilation.ClearQEC import ClearQEC
from .Transpilation.UnBox import UnBox

from qiskit import _numpy_compat
from qiskit.exceptions import QiskitError

from typing import TYPE_CHECKING
from typing import Iterable

class LogicalCircuit(QuantumCircuit):
    """
    Core LogicalQ representation of a logical quantum circuit.
    """

    def __init__(
        self,
        n_logical_qubits: int,
        label: Iterable[int],
        stabilizer_tableau: Iterable[str],
        name: str = None,
    ):
        """Create a logical circuit for the code given by ``label`` and ``stabilizer_tableau``.

        Args:
            n_logical_qubits: Number of logical qubits to allocate immediately.
            label: Code label, unpacked as ``(n, k, d)``.
            stabilizer_tableau: Stabilizer strings; every entry must have length ``n``.
            name: Optional name passed on to the underlying ``QuantumCircuit``.

        Raises:
            ValueError: If any stabilizer's length differs from ``n``.
        """
        # --- Description of the quantum error correcting code ---
        self.n_logical_qubits = n_logical_qubits

        self.stabilizer_tableau = stabilizer_tableau
        self.n_stabilizers = len(stabilizer_tableau)

        self.label = label
        self.n, self.k, self.d = label
        self.n_physical_qubits = self.n

        stabilizer_lengths = [len(stabilizer) for stabilizer in stabilizer_tableau]
        if any(length != self.n for length in stabilizer_lengths):
            raise ValueError(f"Stabilizer lengths do not all equal the code label n ({self.n})")

        # @TODO - obtain an exact estimate for the number of ancilla qubits
        self.n_ancilla_qubits = self.n_stabilizers//2
        self.n_measure_qubits = self.n_ancilla_qubits

        # Stabilizer groupings, filled in by group_stabilizers()
        self.flagged_stabilizers_1 = []
        self.flagged_stabilizers_2 = []
        self.x_stabilizers = []
        self.z_stabilizers = []

        # Build the code (generator matrix, logical operators, encoder), then group its stabilizers
        self.generate_code()
        self.group_stabilizers()

        # Bookkeeping for where QED/QEC cycles sit in the circuit
        self.qed_cycle_indices_initial = {}
        self.qed_cycle_indices_final = {}
        self.qec_cycle_indices_initial = {}
        self.qec_cycle_indices_final = {}

        # Per-logical-qubit register storage: quantum registers first...
        self.logical_qregs = []
        self.ancilla_qregs = []
        self.logical_op_qregs = []
        # ...then classical registers
        self.enc_verif_cregs = []
        self.curr_syndrome_cregs = []
        self.prev_syndrome_cregs = []
        self.flagged_syndrome_diff_cregs = []
        self.unflagged_syndrome_diff_cregs = []
        self.pauli_frame_cregs = []
        self.logical_op_meas_cregs = []
        self.final_measurement_cregs = []

        # Aggregate views holding the very same list objects as above
        self.qreg_lists = [
            self.logical_qregs,
            self.ancilla_qregs,
            self.logical_op_qregs,
        ]
        self.creg_lists = [
            self.enc_verif_cregs,
            self.curr_syndrome_cregs,
            self.prev_syndrome_cregs,
            self.flagged_syndrome_diff_cregs,
            self.unflagged_syndrome_diff_cregs,
            self.pauli_frame_cregs,
            self.logical_op_meas_cregs,
            self.final_measurement_cregs,
        ]

        # Initialise the (empty) base QuantumCircuit, then populate it with the logical qubits
        super().__init__(name=name)
        self.add_logical_qubits(self.n_logical_qubits)

        # Classical register receiving one output bit per logical qubit
        self.output_creg = ClassicalRegister(self.n_logical_qubits, name="coutput")
        super().add_register(self.output_creg)

        # Two-qubit helper register used to set classical bits dynamically; its second qubit gets an X
        # @TODO - find alternative, possibly by implementing upstream
        self.cbit_setter_qreg = QuantumRegister(2, name="qsetter")
        self.add_register(self.cbit_setter_qreg)
        super().x(self.cbit_setter_qreg[1])

        # Independent snapshots of the instruction data taken at construction time
        self.data_without_qec = copy.deepcopy(self.data)
        self.data_without_qed = copy.deepcopy(self.data)

    # @TODO - this completely ignores QEC (besides encoding), do we want to have some sort of default QEC behavior?
    @classmethod
    def from_physical_circuit(
        cls,
        physical_circuit: QuantumCircuit,
        label: Iterable[int],
        stabilizer_tableau: Iterable[str],
        name: str = None,
        max_iterations: int = 1
    ) -> LogicalCircuit:
        """Build a LogicalCircuit that mirrors a physical Qiskit circuit.

        One logical qubit is allocated per qubit of ``physical_circuit``. All of them are
        encoded, and then every instruction of ``physical_circuit`` is appended in order.

        Args:
            physical_circuit: The QuantumCircuit to construct from.
            label: The QECC label.
            stabilizer_tableau: The set of stabilizers for the QECC.
            name: An optional name for the circuit.
            max_iterations: Maximum number of encoding attempts, forwarded to ``encode``.

        Returns:
            The LogicalCircuit constructed from the physical circuit.
        """
        qubit_count = physical_circuit.num_qubits

        built_circuit = cls(qubit_count, label, stabilizer_tableau, name)
        built_circuit.encode(*range(qubit_count), max_iterations=max_iterations)

        # Replay the physical instructions in their original order
        for instruction in physical_circuit.data:
            built_circuit.append(instruction)

        return built_circuit

    def add_logical_qubits(
        self,
        logical_qubit_count: int
    ):
        """Allocate registers for additional logical qubit(s) and add them to the circuit.

        Each new logical qubit gets three quantum registers and eight classical registers.
        They are recorded in the matching ``self.*_qregs`` / ``self.*_cregs`` lists and
        registered on the underlying circuit. Empty QEC cycle index lists are created too.

        Args:
            logical_qubit_count: The number of logical qubits to add.
        """
        first_new_index = len(self.logical_qregs)

        # @TODO - refactor to use LogicalQubit
        for index in range(first_new_index, first_new_index + logical_qubit_count):
            # (storage list, new register) pairs, in the order the registers are added to the circuit
            allocations = (
                # Physical qubits for logical qubit
                (self.logical_qregs, QuantumRegister(self.n_physical_qubits, name=f"qlog{index}")),
                # Ancilla qubits needed for measurements
                (self.ancilla_qregs, AncillaRegister(self.n_ancilla_qubits, name=f"qanc{index}")),
                # Ancilla qubits needed for logical operations
                (self.logical_op_qregs, AncillaRegister(2, name=f"qlogical_op{index}")),
                # Classical bits needed for encoding verification
                (self.enc_verif_cregs, ClassicalRegister(1, name=f"cenc_verif{index}")),
                # Classical bits needed for measurements
                (self.curr_syndrome_cregs, ClassicalRegister(self.n_measure_qubits, name=f"ccurr_syndrome{index}")),
                # Classical bits needed for previous syndrome measurements
                (self.prev_syndrome_cregs, ClassicalRegister(self.n_stabilizers, name=f"cprev_syndrome{index}")),
                # Classical bits needed for flagged syndrome difference measurements
                (self.flagged_syndrome_diff_cregs, ClassicalRegister(self.n_stabilizers, name=f"cflagged_syndrome_diff{index}")),
                # Classical bits needed for unflagged syndrome difference measurements
                (self.unflagged_syndrome_diff_cregs, ClassicalRegister(self.n_stabilizers, name=f"cunflagged_syndrome_diff{index}")),
                # Classical bits needed to track the Pauli Frame
                (self.pauli_frame_cregs, ClassicalRegister(2, name=f"cpauli_frame{index}")),
                # Classical bits needed to take measurements of logical operation qubits
                (self.logical_op_meas_cregs, ClassicalRegister(2, name=f"clogical_op_meas{index}")),
                # Classical bits needed to take measurements of the final state of the logical qubit
                (self.final_measurement_cregs, ClassicalRegister(self.n_physical_qubits, name=f"cfinal_meas{index}")),
            )

            # Record every new register in its storage list first...
            for storage, register in allocations:
                storage.append(register)

            # ...then register them all on the underlying quantum circuit
            for _, register in allocations:
                super().add_register(register)

            # QEC cycle indices
            self.qec_cycle_indices_initial[index] = []
            self.qec_cycle_indices_final[index] = []

    ####################################
    ##### Quantum error correction #####
    ####################################

    def group_stabilizers(self):
        """Generate the stabilizers for the code.
        """
        # @TODO - determine how stabilizers are generally selected for flagged measurements
        #       - the below is a heuristic which happens to work for the Steane code and potentially all CSS codes, but maybe not all stabilizer codes in general

        # Take the middle k stabilizers
        k = self.n_stabilizers//2
        self.flagged_stabilizers_1 = [s for s in range(self.n_stabilizers) if s < k - k//2 - 1 or s > k + k//2 - 1]
        self.flagged_stabilizers_2 = list(set(range(self.n_stabilizers)) - set(self.flagged_stabilizers_1))

        for i in range(self.n_stabilizers):
            if 'X' in self.stabilizer_tableau[i]:
                self.x_stabilizers.append(i)
            if 'Z' in self.stabilizer_tableau[i]:
                self.z_stabilizers.append(i)

    def generate_code(self):
        """Generate the encoding circuit and logical operators for the selected tableau.

        The method proceeds in four steps:

        1. Assemble the binary generator matrix ``G`` of shape ``(2, m, n)`` from
           ``self.stabilizer_tableau`` (``G[0]`` holds X components, ``G[1]`` holds Z components;
           a ``Y`` sets both).
        2. Reduce ``G`` with Gaussian elimination over GF(2), first on the X block and then on
           the Z block of the remaining rows, and move pivot columns forward.
        3. Read the logical X/Z Pauli vectors off the reduced blocks and build the logical
           operator circuits from them (X, Z, Y, controlled versions, and the LCU / coherent
           feedback constructions for H, S, S^dagger, T, T^dagger).
        4. Build the encoding circuit from the logical X vector and the reduced ``G``.

        Sets ``self.G``, ``self.LogicalXVector``, ``self.LogicalZVector``, the ``Logical*Circuit``
        and ``PhysicalToLogical*Circuit`` attributes, ``self.encoding_circuit`` and
        ``self.encoding_gate``.
        """
        m = len(self.stabilizer_tableau)

        # Step 1: Assemble generator matrix
        G = np.zeros((2, m, self.n))
        for i, stabilizer in enumerate(self.stabilizer_tableau):
            for j, pauli_j in enumerate(stabilizer):
                if pauli_j == "X":
                    G[0, i, j] = 1
                elif pauli_j == "Z":
                    G[1, i, j] = 1
                elif pauli_j == "Y":
                    G[:, i, j] = 1

        # Step 2: Perform Gaussian reduction in base 2
        row = 0
        for col in range(self.n):
            pivot_row = None
            for i in range(row, m):
                if G[0, i, col] == 1:
                    pivot_row = i
                    break

            if pivot_row is None:
                continue

            G[:, [row, pivot_row]] = G[:, [pivot_row, row]]

            # Flip any other rows with a "1" in the same column (XOR acts on the X and Z parts together)
            for i in range(m):
                if i != row and G[0, i, col] == 1:
                    G[:, i] = G[:, i].astype(int) ^ G[:, row].astype(int)

            # Move to the next row, if we haven't reached the end of the matrix
            row += 1
            if row >= m:
                break

        r = np.linalg.matrix_rank(G[0])

        # Second elimination: Z block of the rows/columns beyond the first r, mirrored onto G
        E = np.copy(G[:, r:, r:])
        row = 0
        for col in range(self.n-r):
            pivot_row = None
            for i in range(row, m-r):
                if E[1, i, col] == 1:
                    pivot_row = i
                    break

            if pivot_row is None:
                continue

            E[:, [row, pivot_row]] = E[:, [pivot_row, row]]
            G[:, [r+row, r+pivot_row]] = G[:, [r+pivot_row, r+row]]

            # Flip any other rows with a "1" in the same column
            for i in range(m-r):
                if i != row and E[1, i, col] == 1:
                    E[:, i] = E[:, i].astype(int) ^ E[:, row].astype(int)
                    G[:, r+i] = G[:, r+i].astype(int) ^ G[:, r+row].astype(int)

            # Move to the next row, if we haven't reached the end of the sub-matrix (it has m-r rows)
            row += 1
            if row >= m-r:
                break

        # Since G is in RREF, a pivot row is also a pivot column, so find the pivot columns and move them forward
        pivot_indices = []
        for row in G[0]:
            if 1 in row:
                pivot_indices.append(int(np.where(row == 1)[0][0]))

        for diagonal_index, pivot_index in enumerate(pivot_indices):
            G[:, :, [diagonal_index, pivot_index]] = G[:, :, [pivot_index, diagonal_index]]

        self.G = G

        # Step 3: Construct logical operators using Pauli vector representations due to Gottesman (1997)
        r = np.linalg.matrix_rank(self.G[0])
        A_2 = self.G[0, 0:r, m:self.n] # r x k
        C_1 = self.G[1, 0:r, r:m] # r x m-r
        C_2 = self.G[1, 0:r, m:self.n] # r x k
        E_2 = self.G[1, r:m, m:self.n] # m-r x k

        # The Z part is a GF(2) expression, so it is reduced mod 2. Without the reduction an
        # odd entry greater than 1 would be skipped by the `bit == 1` test below.
        self.LogicalXVector = np.block([
            [[np.zeros((self.k, r)), E_2.T,                        np.eye(self.k, self.k)    ]],
            [[(E_2.T @ C_1.T + C_2.T) % 2, np.zeros((self.k, m-r)), np.zeros((self.k, self.k))]]
        ])

        # Create Logical X circuit corresponding to X's and Z's at 1's in Pauli vector
        self.LogicalXCircuit = QuantumCircuit(self.n, name="logical.logicalop.x.transversal:$\\hat{X}_{L}$")
        for i in range(self.k):
            # X part
            for q, bit in enumerate(self.LogicalXVector[0][i]):
                if bit == 1:
                    self.LogicalXCircuit.x(q)
            # Z part
            for q, bit in enumerate(self.LogicalXVector[1][i]):
                if bit == 1:
                    self.LogicalXCircuit.z(q)

        self.LogicalZVector = np.block([
            [[np.zeros((self.k, r)), np.zeros((self.k, m-r)), np.zeros((self.k, self.k))]],
            [[A_2.T,                 np.zeros((self.k, m-r)), np.eye(self.k, self.k)    ]]
        ])

        # Create Logical Z circuit corresponding to X's and Z's at 1's in Pauli vector
        self.LogicalZCircuit = QuantumCircuit(self.n, name="logical.logicalop.z.transversal:$\\hat{Z}_{L}$")
        for i in range(self.k):
            # X part
            for q, bit in enumerate(self.LogicalZVector[0][i]):
                if bit == 1:
                    self.LogicalZCircuit.x(q)
            # Z part
            for q, bit in enumerate(self.LogicalZVector[1][i]):
                if bit == 1:
                    self.LogicalZCircuit.z(q)

        # Logical Y is built as the logical X circuit followed by the logical Z circuit
        self.LogicalYCircuit = self.LogicalXCircuit.compose(self.LogicalZCircuit)
        self.LogicalYCircuit.name = "logical.logicalop.y.transversal:$\\hat{Y}_{L}$"

        # Singly- and doubly-controlled logical X/Z, with physical control qubit(s)
        self.PhysicalToLogicalCXCircuit = self.LogicalXCircuit.control(1, label="logical.logicalop.cx.transversal:$\\hat{CX}_{L}$")
        self.PhysicalToLogicalCCXCircuit = self.LogicalXCircuit.control(2, label="logical.logicalop.ccx.transversal:$\\hat{CCX}_{L}$")
        self.PhysicalToLogicalCZCircuit = self.LogicalZCircuit.control(1, label="logical.logicalop.cz.transversal:$\\hat{CZ}_{L}$")
        self.PhysicalToLogicalCCZCircuit = self.LogicalZCircuit.control(2, label="logical.logicalop.ccz.transversal:$\\hat{CCZ}_{L}$")
        self.PhysicalToLogicalCXCircuit.name = "logical.logicalop.cx.transversal:$\\hat{CX}_{L}$"
        self.PhysicalToLogicalCCXCircuit.name = "logical.logicalop.ccx.transversal:$\\hat{CCX}_{L}$"
        self.PhysicalToLogicalCZCircuit.name = "logical.logicalop.cz.transversal:$\\hat{CZ}_{L}$"
        self.PhysicalToLogicalCCZCircuit.name = "logical.logicalop.ccz.transversal:$\\hat{CCZ}_{L}$"

        # In all circuits below, qubits 0..n-1 are the code block and qubits n (and n+1) are extra control qubits

        # Create Logical H circuit using Childs and Wiebe's linear combination of unitaries method
        self.LogicalHCircuit_LCU = QuantumCircuit(self.n + 1)
        self.LogicalHCircuit_LCU.h(self.n)
        self.LogicalHCircuit_LCU.compose(self.PhysicalToLogicalCXCircuit, [self.n, *list(range(self.n))], inplace=True)
        self.LogicalHCircuit_LCU.x(self.n)
        self.LogicalHCircuit_LCU.compose(self.PhysicalToLogicalCZCircuit, [self.n, *list(range(self.n))], inplace=True)
        self.LogicalHCircuit_LCU.x(self.n)
        self.LogicalHCircuit_LCU.h(self.n)

        # Creates Logical H circuit using coherent feedback
        self.LogicalHCircuit_CF = QuantumCircuit(self.n + 1)
        self.LogicalHCircuit_CF.h(self.n)
        self.LogicalHCircuit_CF.compose(self.PhysicalToLogicalCXCircuit, [self.LogicalHCircuit_CF.qubits[self.n]] + self.LogicalHCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalHCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalHCircuit_CF.qubits[self.n]] + self.LogicalHCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalHCircuit_CF.h(self.n)
        self.LogicalHCircuit_CF.compose(self.PhysicalToLogicalCXCircuit, [self.LogicalHCircuit_CF.qubits[self.n]] + self.LogicalHCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalHCircuit_CF.x(self.n)
        self.LogicalHCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalHCircuit_CF.qubits[self.n]] + self.LogicalHCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalHCircuit_CF.h(self.n)

        # Creates Logical S circuit using coherent feedback
        self.LogicalSCircuit_CF = QuantumCircuit(self.n + 1)
        self.LogicalSCircuit_CF.h(self.n)
        self.LogicalSCircuit_CF.s(self.n)
        self.LogicalSCircuit_CF.h(self.n)
        self.LogicalSCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalSCircuit_CF.qubits[self.n]] + self.LogicalSCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalSCircuit_CF.h(self.n)
        self.LogicalSCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalSCircuit_CF.qubits[self.n]] + self.LogicalSCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalSCircuit_CF.sdg(self.n)
        self.LogicalSCircuit_CF.h(self.n)

        # Creates Logical S^dagger circuit using coherent feedback
        self.LogicalSdgCircuit_CF = QuantumCircuit(self.n + 1)
        self.LogicalSdgCircuit_CF.h(self.n)
        self.LogicalSdgCircuit_CF.sdg(self.n)
        self.LogicalSdgCircuit_CF.h(self.n)
        self.LogicalSdgCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalSdgCircuit_CF.qubits[self.n]] + self.LogicalSdgCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalSdgCircuit_CF.h(self.n)
        self.LogicalSdgCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalSdgCircuit_CF.qubits[self.n]] + self.LogicalSdgCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalSdgCircuit_CF.s(self.n)
        self.LogicalSdgCircuit_CF.h(self.n)

        # Creates Logical T circuit using coherent feedback
        self.LogicalTCircuit_CF = QuantumCircuit(self.n + 2)
        self.LogicalTCircuit_CF.h(self.n)
        self.LogicalTCircuit_CF.h(self.n + 1)
        self.LogicalTCircuit_CF.t(self.n)
        self.LogicalTCircuit_CF.s(self.n + 1)
        self.LogicalTCircuit_CF.h(self.n)
        self.LogicalTCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalTCircuit_CF.qubits[self.n]] + self.LogicalTCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTCircuit_CF.h(self.n)
        self.LogicalTCircuit_CF.h(self.n + 1)
        self.LogicalTCircuit_CF.compose(self.PhysicalToLogicalCCZCircuit, self.LogicalTCircuit_CF.qubits[self.n:self.n+2] + self.LogicalTCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTCircuit_CF.h(self.n + 1)
        self.LogicalTCircuit_CF.compose(self.PhysicalToLogicalCCZCircuit, self.LogicalTCircuit_CF.qubits[self.n:self.n+2] + self.LogicalTCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTCircuit_CF.tdg(self.n)
        self.LogicalTCircuit_CF.sdg(self.n + 1)
        self.LogicalTCircuit_CF.h(self.n)
        self.LogicalTCircuit_CF.h(self.n + 1)

        # Creates Logical T^dagger circuit using coherent feedback
        self.LogicalTdgCircuit_CF = QuantumCircuit(self.n + 2)
        self.LogicalTdgCircuit_CF.h(self.n)
        self.LogicalTdgCircuit_CF.h(self.n + 1)
        self.LogicalTdgCircuit_CF.tdg(self.n)
        self.LogicalTdgCircuit_CF.sdg(self.n + 1)
        self.LogicalTdgCircuit_CF.h(self.n)
        self.LogicalTdgCircuit_CF.compose(self.PhysicalToLogicalCZCircuit, [self.LogicalTdgCircuit_CF.qubits[self.n]] + self.LogicalTdgCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTdgCircuit_CF.h(self.n)
        self.LogicalTdgCircuit_CF.h(self.n + 1)
        self.LogicalTdgCircuit_CF.compose(self.PhysicalToLogicalCCZCircuit, self.LogicalTdgCircuit_CF.qubits[self.n:self.n+2] + self.LogicalTdgCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTdgCircuit_CF.h(self.n + 1)
        self.LogicalTdgCircuit_CF.compose(self.PhysicalToLogicalCCZCircuit, self.LogicalTdgCircuit_CF.qubits[self.n:self.n+2] + self.LogicalTdgCircuit_CF.qubits[:self.n], inplace=True)
        self.LogicalTdgCircuit_CF.t(self.n)
        self.LogicalTdgCircuit_CF.s(self.n + 1)
        self.LogicalTdgCircuit_CF.h(self.n)
        self.LogicalTdgCircuit_CF.h(self.n + 1)

        # @TODO - Logical CX

        # Step 4: Apply the respective stabilizers
        self.encoding_circuit = QuantumCircuit(self.n)
        # CX from each of the last k qubits onto the columns r..n-k-1 flagged in the logical X vector's X part
        for i in range(self.k):
            for j in range(r, self.n-self.k):
                if self.LogicalXVector[0, i, j]:
                    self.encoding_circuit.cx(self.n-self.k+i, j)

        # For each of the first r generators: H on qubit i, then controlled Paulis from qubit i
        # onto every other qubit j according to the generator's X and Z entries
        for i in range(r):
            self.encoding_circuit.h(i)
            for j in range(self.n):
                if i != j:
                    if self.G[0, i, j] and self.G[1, i, j]:
                        self.encoding_circuit.cx(i, j)
                        self.encoding_circuit.cz(i, j)
                    elif self.G[0, i, j]:
                        self.encoding_circuit.cx(i, j)
                    elif self.G[1, i, j]:
                        self.encoding_circuit.cz(i, j)

        self.encoding_gate = self.encoding_circuit.to_gate(label="$U_{enc}$")

    def encode(
        self,
        *qubits: Iterable[int],
        max_iterations: int = 1,
        initial_states: Iterable[list[int]] | None = None
    ) -> bool:
        """
        Prepare logical qubit(s) in the specified initial state.

        For every requested logical qubit, this resets its physical qubits, applies the
        encoding circuit and runs an ancilla-based verification. The verification is retried
        up to ``max_iterations - 1`` times, conditioned on the verification bit being 1. If
        the initial state is 1, ``self.x`` is then applied to the logical qubit.

        All arguments are validated before anything is appended to the circuit, so a rejected
        call leaves the circuit untouched.

        Args:
            qubits: Logical qubit indices to encode, given either as separate positional
                integers or as one or more iterables of integers.
            max_iterations: Maximum number of encoding attempts per logical qubit.
            initial_states: Initial state (0 or 1) for each qubit, in the same order as
                ``qubits``. Defaults to all zeros.

        Returns:
            True, always.

        Raises:
            RuntimeError: If the encoding circuit has not been generated.
            ValueError: If no qubits are given, if nested qubit lists contain duplicates, if
                the number of initial states differs from the number of qubits, or if an
                initial state is neither 0 nor 1.
        """

        if getattr(self, "encoding_circuit", None) is None:
            raise RuntimeError("LogicalCircuit code has not been properly constructed (missing encoding circuit)")

        # *qubits always arrives as a tuple (never None), so "nothing given" is the empty tuple
        if len(qubits) == 0:
            raise ValueError("No qubits specified for logical state encoding")

        if hasattr(qubits[0], "__iter__"):
            # Double unwrapping in case qubits is actually a list of lists
            qubits = [qj for qi in qubits for qj in qi]

            # Check for duplicates - we could just apply set, but a common misconception is that you can call
            # LogicalCircuit.encode([0], [0]) to encode the 0th qubit in the logical 0 state, so we warn the user instead
            if len(set(qubits)) < len(qubits):
                raise ValueError("Qubits input contains duplicate values - if you are pass a list of initial states, you must specify it by the keyword initial_states")
        else:
            # Simple list conversion to guarantee type
            qubits = list(qubits)

        # Materialise the states so they can be both validated and iterated
        if initial_states is None:
            initial_states = [0] * len(qubits)
        else:
            initial_states = list(initial_states)

        if len(qubits) != len(initial_states):
            raise ValueError("Number of qubits should equal number of initial states if initial states are provided")

        # Reject bad states up front rather than after gates have already been appended
        for init_state in initial_states:
            if init_state not in (0, 1):
                raise ValueError("Initial state should be either 0 or 1 (arbitrary statevectors not yet supported)!")

        # Physical qubits whose parity is copied onto the verification ancilla (Z1 Z3 Z5)
        verification_qubits = (1, 3, 5)

        for q, init_state in zip(qubits, initial_states):
            logical_qreg = self.logical_qregs[q]
            verification_ancilla = self.ancilla_qregs[q][0]
            verification_cbit = self.enc_verif_cregs[q][0]

            with self.box(label="logical.qec.encode:$\\hat U_{enc}$"):
                # Preliminary physical qubit reset
                super().reset(logical_qreg)

                # Initial encoding
                super().compose(self.encoding_circuit, logical_qreg, inplace=True)

                with self.box(label="logical.qec.encoding_verification:$\\hat U_{enc,verif}$"):
                    # CNOT from physical qubits to ancilla(e)
                    for p in verification_qubits:
                        super().cx(logical_qreg[p], verification_ancilla)

                    # Measure ancilla(e)
                    super().append(Measure(), [verification_ancilla], [verification_cbit], copy=False)

                    for _ in range(max_iterations - 1):
                        # If the ancilla stores a 1, reset the entire logical qubit and redo
                        with super().if_test((verification_cbit, 1)) as _else:
                            super().reset(logical_qreg)

                            # Initial encoding
                            super().compose(self.encoding_circuit, logical_qreg, inplace=True)

                            # CNOT from (Z1 Z3 Z5) to ancilla
                            for p in verification_qubits:
                                super().cx(logical_qreg[p], verification_ancilla)

                            # Measure ancilla
                            super().append(Measure(), [verification_ancilla], [verification_cbit], copy=False)
                        with _else:
                            pass

                    # Reset ancilla qubit
                    super().reset(verification_ancilla)

                # Flip qubits if necessary (states were validated above, so only 0 or 1 remain)
                if init_state == 1:
                    self.x(q)

        return True

    def reset_ancillas(
        self,
        logical_qubit_indices: Iterable[int] | None = None
    ):
        """Reset all ancillas associated with specified logical qubits.

        Args:
            logical_qubit_indices: Indices of logical qubits to reset. If None, then reset all.
        """
        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        for q in logical_qubit_indices:
            self.reset(self.ancilla_qregs[q])

    def steane_flagged_circuit1(
        self,
        logical_qubit_indices: Iterable[int]
    ):
        """Append the first flagged syndrome circuit for the Steane code.

        For each listed logical qubit, ancilla 0 is wrapped in Hadamards and acts as a CX
        control, while ancillas 1 and 2 only appear as CX targets. The whole sequence is
        enclosed in barriers. No measurement is appended here.

        Args:
            logical_qubit_indices: Indices of the logical qubits to act on.
        """
        # (control, target) pairs in application order; each endpoint is (register, index)
        # with "a" = ancilla register and "l" = logical (data) register.
        cx_schedule = (
            (("a", 0), ("l", 3)),
            (("l", 2), ("a", 2)),
            (("l", 5), ("a", 1)),
            (("a", 0), ("a", 1)),
            (("a", 0), ("l", 0)),
            (("l", 3), ("a", 2)),
            (("l", 4), ("a", 1)),
            (("a", 0), ("l", 1)),
            (("l", 6), ("a", 2)),
            (("l", 2), ("a", 1)),
            (("a", 0), ("a", 2)),
            (("a", 0), ("l", 2)),
            (("l", 5), ("a", 2)),
            (("l", 1), ("a", 1)),
        )

        for q in logical_qubit_indices:
            super().barrier()

            registers = {"a": self.ancilla_qregs[q], "l": self.logical_qregs[q]}
            hadamard_ancilla = registers["a"][0]

            super().h(hadamard_ancilla)
            for (ctrl_reg, ctrl_idx), (tgt_reg, tgt_idx) in cx_schedule:
                super().cx(registers[ctrl_reg][ctrl_idx], registers[tgt_reg][tgt_idx])
            super().h(hadamard_ancilla)

            super().barrier()

    def steane_flagged_circuit2(
        self,
        logical_qubit_indices: Iterable[int]
    ):
        """Append the second flagged syndrome circuit for the Steane code.

        For each listed logical qubit, ancillas 1 and 2 are wrapped in Hadamards and act as
        CX controls, while ancilla 0 only appears as a CX target. The whole sequence is
        enclosed in barriers. No measurement is appended here.

        Args:
            logical_qubit_indices: Indices of the logical qubits to act on.
        """
        # (control, target) pairs in application order; each endpoint is (register, index)
        # with "a" = ancilla register and "l" = logical (data) register.
        cx_schedule = (
            (("l", 3), ("a", 0)),
            (("a", 2), ("l", 2)),
            (("a", 1), ("l", 5)),
            (("a", 1), ("a", 0)),
            (("l", 0), ("a", 0)),
            (("a", 2), ("l", 3)),
            (("a", 1), ("l", 4)),
            (("l", 1), ("a", 0)),
            (("a", 2), ("l", 6)),
            (("a", 1), ("l", 2)),
            (("a", 2), ("a", 0)),
            (("l", 2), ("a", 0)),
            (("a", 2), ("l", 5)),
            (("a", 1), ("l", 1)),
        )

        for q in logical_qubit_indices:
            super().barrier()

            registers = {"a": self.ancilla_qregs[q], "l": self.logical_qregs[q]}
            hadamard_ancillas = (registers["a"][1], registers["a"][2])

            for ancilla in hadamard_ancillas:
                super().h(ancilla)
            for (ctrl_reg, ctrl_idx), (tgt_reg, tgt_idx) in cx_schedule:
                super().cx(registers[ctrl_reg][ctrl_idx], registers[tgt_reg][tgt_idx])
            for ancilla in hadamard_ancillas:
                super().h(ancilla)

            super().barrier()

    def measure_stabilizers(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        stabilizer_indices: Iterable[int] | None = None
    ):
        """Append the specified stabilizers to the circuit as controlled Pauli operators.

        For each selected logical qubit, the s-th requested stabilizer is coupled to ancilla
        ``s`` of that qubit's ancilla register: a Hadamard on the ancilla, one controlled
        Pauli per non-identity entry of the stabilizer, then another Hadamard. The ancilla
        measurement itself is not appended here.

        Because ancillas are indexed by position in ``stabilizer_indices``, the number of
        requested stabilizers must not exceed the size of the ancilla register.

        Args:
            logical_qubit_indices: Indices of logical qubits for which to measure stabilizers.
                ``None`` or empty selects all logical qubits.
            stabilizer_indices: Indices of stabilizers to measure. ``None`` or empty selects
                all stabilizers.
        """
        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        # The emptiness test must look at stabilizer_indices itself. It previously checked
        # logical_qubit_indices, so an empty stabilizer list appended nothing at all.
        if stabilizer_indices is None or len(stabilizer_indices) == 0:
            stabilizer_indices = list(range(self.n_stabilizers))

        for q in logical_qubit_indices:
            ancilla_qreg = self.ancilla_qregs[q]
            logical_qreg = self.logical_qregs[q]

            for s, stabilizer_index in enumerate(stabilizer_indices):
                stabilizer = self.stabilizer_tableau[stabilizer_index]
                ancilla = ancilla_qreg[s]

                super().h(ancilla)
                for p in range(self.n_physical_qubits):
                    # Identity entries contribute no gate
                    if stabilizer[p] == 'I':
                        continue

                    CPauliInstruction = Pauli(stabilizer[p]).to_instruction().control(1)
                    super().append(CPauliInstruction, [ancilla, logical_qreg[p]])
                super().h(ancilla)

    def measure_syndrome_diff(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        stabilizer_indices: Iterable[int] | None = None,
        flagged: bool = False,
        steane_flag_1: bool = False,
        steane_flag_2: bool = False
    ):
        """Measure flagged or unflagged syndrome differences for specified logical qubits and stabilizers.

        For each logical qubit, the stabilizer circuit is applied, every ancilla is measured into
        the current-syndrome register, and each requested stabilizer's syndrome-difference bit is
        set to 1 or 0 depending on the test built from the current and previous syndrome bits.
        The ancillas of all requested logical qubits are reset at the end.

        Args:
            logical_qubit_indices: Logical qubits for which to measure syndrome differences.
                Defaults to all logical qubits if `None` or empty.
            stabilizer_indices: Stabilizers for which to measure syndrome differences.
                Defaults to all stabilizers if `None` or empty.
            flagged: Whether to write into the flagged (True) or unflagged (False) syndrome-difference registers.
            steane_flag_1: Whether to measure the first Steane code syndrome. Takes priority over `steane_flag_2`.
            steane_flag_2: Whether to measure the second Steane code syndrome. Ignored if `steane_flag_1` is True.
        """
        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        if stabilizer_indices is None or len(stabilizer_indices) == 0:
            stabilizer_indices = list(range(self.n_stabilizers))

        for q in logical_qubit_indices:
            syndrome_diff_creg = self.flagged_syndrome_diff_cregs[q] if flagged else self.unflagged_syndrome_diff_cregs[q]

            # Apply and measure stabilizers for the desired syndrome.
            # Only the current logical qubit is passed: this loop already visits every requested
            # qubit, so passing the whole list would re-apply the circuit once per iteration.
            if steane_flag_1:
                self.steane_flagged_circuit1([q])
            elif steane_flag_2:
                self.steane_flagged_circuit2([q])
            else:
                self.measure_stabilizers(logical_qubit_indices=[q], stabilizer_indices=stabilizer_indices)

            for n in range(self.n_ancilla_qubits):
                super().append(Measure(), [self.ancilla_qregs[q][n]], [self.curr_syndrome_cregs[q][n]], copy=False)

            # Determine the syndrome difference: the n-th ancilla result is compared with the
            # previous syndrome bit of the n-th requested stabilizer
            for n, stabilizer_index in enumerate(stabilizer_indices):
                with self.if_test(self.cbit_xor([self.curr_syndrome_cregs[q][n], self.prev_syndrome_cregs[q][stabilizer_index]])) as _else:
                    self.set_cbit(syndrome_diff_creg[stabilizer_index], 1)
                with _else:
                    self.set_cbit(syndrome_diff_creg[stabilizer_index], 0)

        self.reset_ancillas(logical_qubit_indices=logical_qubit_indices)

    def optimize_qec_cycle_indices(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        constraint_model: dict[str, float] | None = None,
        ignore_existing_qec: bool = False,
        clear_existing_qec: bool = False
    ) -> dict[int, list[int]]:
        """Compute optimal QEC cycle indices.

        The circuit is sliced by depth to assign a depth to every instruction. Then, for each
        requested logical qubit, the instructions touching that qubit's registers are walked in
        order while per-constraint counters and a running cost are accumulated. Whenever a
        constraint is met, the instruction index is recorded and the counters and cost are reset.

        Keys read from `constraint_model`:
            - ``effective_threshold`` (required): limit for the running cost.
            - ``circuit_depth_logical_qubit``: depth allowed since the previous recorded index.
            - ``num_<name>`` / ``cost_<name>``: count limit / cost per instruction name.
            - ``num_ops_<k>q`` / ``cost_ops_<k>q``: count limit / cost per k-qubit instruction.
            - ``cost_ops_clifford_<k>q``, ``cost_ops_pauli_<k>q``, ``cost_ops_nonpauli_<k>q``,
              ``cost_ops_nonclifford_<k>q``: costs by gate class.

        Args:
            logical_qubit_indices: Logical qubits for which to compute optimal QEC cycle indices.
                Defaults to all logical qubits if `None` or empty.
            constraint_model: Constraint model, i.e., dictionary of gadget costs for the circuit.
            ignore_existing_qec: Whether to ignore QEC-related parameters in the constraint model.
                Currently only used for input validation.
            clear_existing_qec: Whether to remove existing QEC cycles.
                Currently only used for input validation.

        Returns:
            Dictionary mapping each logical qubit for which at least one constraint was met to
            the list of instruction indices at which a QEC cycle should be inserted.

        Raises:
            ValueError: If `clear_existing_qec` is set without `ignore_existing_qec`, or if
                `constraint_model` is `None`.
            KeyError: If `constraint_model` lacks ``effective_threshold`` and an instruction
                involving a requested logical qubit is processed.
        """
        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        if clear_existing_qec and not ignore_existing_qec:
            raise ValueError("Clear existing QEC requested but not ignore existing QEC, which is likely to result in index errors because existing QEC cycles are cleared before new ones are inserted at the computed indices")

        if constraint_model is None:
            raise ValueError("A valid constraint_model input is required by optimize_qec_cycle_indices")

        # @TODO - if the user has requested that QEC be ignored, check whether there are any QEC-related parameters in the constraint_model

        # @TODO - transpile circuit into basis gates before doing scheduling
        #       - one difficulty this will bring is mapping between indices before and after transpilation

        # Assign a depth to each instruction: every instruction of the d-th slice gets depth d
        slices = slice_by_depth(self, 1)
        depths = []
        for d, depth_slice in enumerate(slices):
            depths.extend([d]*len(depth_slice.data))

        new_qec_cycle_indices_initial = {}
        def compute_instruction_contributions(q, i, depth, instruction, counters, running_cost):
            """Update `counters`/`running_cost` for `instruction`; return whether a counter or depth constraint is met."""
            # Ignore certain "trivial" operations
            if instruction.name in ["barrier"]:
                return False

            # Ignore QEC steps (e.g. encoding, QEC cycles)
            if instruction.label is not None and instruction.label.startswith("logical.qec"):
                return False

            met = False

            # If in a ControlFlowOp, loop over the instructions in the data
            if instruction.is_control_flow():
                for param in instruction.params:
                    if isinstance(param, QuantumCircuit):
                        for sub_instruction in param:
                            # Call first, then combine: every sub-instruction must contribute to the
                            # counters and running cost even after a constraint has already been met
                            met = compute_instruction_contributions(q, i, depth, sub_instruction, counters, running_cost) or met

                # Stop here for this instruction - we currently don't have special costs for ControlFlowOps in this function.
                # The return sits after the loops so that all sub-instructions of all blocks are visited.
                return met

            # @TODO - handle controlled gates
            # NOTE(review): this branch recurses with `instruction.base_gate`; confirm the object
            #               passed here exposes the attributes used below (e.g. `qubits`)
            if instruction.is_controlled_gate():
                if not instruction.is_standard_gate():
                    if instruction.name.startswith("clogical.logicalop"):
                        base_gate = instruction.base_gate
                        met = met or compute_instruction_contributions(q, i, depth, base_gate, counters, running_cost)
                    else:
                        print(f"WARNING - Unrecognized controlled gate with name '{instruction.name}' and label '{instruction.label}' identified, costs may not be accurate")

            # @TODO - unsure whether this is implemented correctly
            if "circuit_depth_logical_qubit" in constraint_model:
                if q in new_qec_cycle_indices_initial:
                    qec_cycle_index_prev = new_qec_cycle_indices_initial[q][-1]
                    depth_prev = depths[qec_cycle_index_prev]
                else:
                    depth_prev = 0

                if (depth - depth_prev) >= constraint_model["circuit_depth_logical_qubit"]:
                    met = True

            # Check instruction-specific criteria
            # @TODO - actually, we shouldn't have to check counters here, only once we escape the top-level instruction,
            #         because we're (currently) not going to insert QEC cycles inside wrapped instructions

            num_name_key = f"num_{instruction.name}"
            cost_name_key = f"cost_{instruction.name}"
            if num_name_key in constraint_model:
                counters[num_name_key] = counters.get(num_name_key, 0) + 1
                met = met or counters[num_name_key] >= constraint_model[num_name_key]
            if cost_name_key in constraint_model:
                running_cost[-1] += constraint_model[cost_name_key]

            n_qubits = len(instruction.qubits)

            num_ops_key = f"num_ops_{n_qubits}q"
            cost_ops_key = f"cost_ops_{n_qubits}q"
            if num_ops_key in constraint_model:
                counters[num_ops_key] = counters.get(num_ops_key, 0) + 1
                met = met or counters[num_ops_key] >= constraint_model[num_ops_key]
            if cost_ops_key in constraint_model:
                running_cost[-1] += constraint_model[cost_ops_key]

            # Classify by gate name: Clifford (split into Pauli / non-Pauli) versus everything else
            if instruction.name in ["x", "y", "z", "h", "s", "cx", "cy", "cz"]:
                if f"cost_ops_clifford_{n_qubits}q" in constraint_model:
                    running_cost[-1] += constraint_model[f"cost_ops_clifford_{n_qubits}q"]

                if instruction.name in ["x", "y", "z"]:
                    if f"cost_ops_pauli_{n_qubits}q" in constraint_model:
                        running_cost[-1] += constraint_model[f"cost_ops_pauli_{n_qubits}q"]
                else:
                    if f"cost_ops_nonpauli_{n_qubits}q" in constraint_model:
                        running_cost[-1] += constraint_model[f"cost_ops_nonpauli_{n_qubits}q"]
            else:
                if f"cost_ops_nonclifford_{n_qubits}q" in constraint_model:
                    running_cost[-1] += constraint_model[f"cost_ops_nonclifford_{n_qubits}q"]

            # @TODO - check key "num_ops_clifford"
            # @TODO - check key "cost_ops_clifford"
            # @TODO - check key "num_ops_non_clifford"
            # @TODO - check key "cost_ops_non_clifford"
            # @TODO - check key "num_logical_ops_transversal"
            # @TODO - check key "cost_logical_ops_transversal"
            # @TODO - check key "num_logical_ops_non_transversal"
            # @TODO - check key "cost_logical_ops_non_transversal"
            # @TODO - check key "circuit_depth_logical_qubit"
            # @TODO - check key "cost_circuit_depth_logical_qubit"
            # @TODO - check key "cost_circuit_depth_logical_circuit"

            return met

        for q in logical_qubit_indices:
            # Counters track constraints which are contributed to and met separately when any one reaches its limit
            counters = {}
            # Running costs tracks collective constraints which sum over contributions from many sources, which must remain below the effective threshold
            # (held in a one-element list so the nested function can update it in place)
            running_cost = [0.0]

            for i, (depth, instruction) in enumerate(zip(depths, self.data)):
                # Check whether instruction involves any qubit of any register belonging to logical qubit q
                instruction_involves_logical_qubit = any(
                    qubit in instruction.qubits
                    for qreg_list in self.qreg_lists
                    for qubit in qreg_list[q]
                )

                if not instruction_involves_logical_qubit:
                    continue

                met = compute_instruction_contributions(q, i, depth, instruction, counters, running_cost)

                met = met or running_cost[-1] >= constraint_model["effective_threshold"]
                if met:
                    new_qec_cycle_indices_initial[q] = new_qec_cycle_indices_initial.get(q, []) + [i]

                    # Reset counters and running cost
                    counters = {}
                    running_cost[-1] = 0.0

        return new_qec_cycle_indices_initial

    # Insert QEC cycles at specified indices in the circuit data
    # @TODO - Extend the method to process qubit-specific indices for user-friendliness
    def insert_qec_cycles(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        qec_cycle_indices: dict[int, list] | list[list] | None = None,
        clear_existing_qec: bool = False
    ) -> tuple[QuantumCircuitData, QuantumCircuitData]:
        """Insert QEC cycles at specified indices in the circuit data.

        The circuit data is deep-copied, emptied, and rebuilt instruction by instruction; before
        the instruction at index ``i`` is re-appended, a QEC cycle is appended for every logical
        qubit that lists ``i``. Indices that do not occur in the original data are ignored.

        Args:
            logical_qubit_indices: Logical qubits for which to insert QEC cycles. If `None` or
                empty, the keys of `qec_cycle_indices` (dict) or all logical qubits (list) are used.
            qec_cycle_indices: Indices at which to insert QEC cycles for each logical qubit, either
                as a dict keyed by logical qubit index or as a list aligned with `logical_qubit_indices`.
            clear_existing_qec: Whether to clear the existing QEC cycles.

        Returns:
            Original circuit data and data after appending QEC cycles.

        Raises:
            ValueError: If `qec_cycle_indices` is missing, has an unsupported type, does not match
                `logical_qubit_indices`, or contains a non-iterable or non-integer entry.
        """
        # Carefully perform all checks beforehand because it's very difficult to catch errors mid-execution

        if qec_cycle_indices is None:
            raise ValueError("qec_cycle_indices is required by insert_qec_cycles")

        if not isinstance(qec_cycle_indices, (dict, list)):
            raise ValueError(f"qec_cycle_indices is {type(qec_cycle_indices)}, dict or list expected")

        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            if isinstance(qec_cycle_indices, dict):
                logical_qubit_indices = qec_cycle_indices.keys()
            else:
                logical_qubit_indices = list(range(self.n_logical_qubits))

        if isinstance(qec_cycle_indices, dict):
            if set(logical_qubit_indices) != set(qec_cycle_indices.keys()):
                raise ValueError("qec_cycle_indices is a dict but its set of keys does not equal the list logical_qubit_indices")
        else:
            if len(logical_qubit_indices) != len(qec_cycle_indices):
                raise ValueError("qec_cycle_indices is a list but its length does not equal length of list logical_qubit_indices")

            # Normalize the list form to the dict form used below
            qec_cycle_indices = dict(zip(logical_qubit_indices, qec_cycle_indices))

        for q, qec_cycle_indices_q in qec_cycle_indices.items():
            if not hasattr(qec_cycle_indices_q, "__iter__"):
                raise ValueError(f"QEC cycle indices input for logical qubit {q} is {type(qec_cycle_indices_q)}, list expected.")

            if any(not isinstance(index, int) for index in qec_cycle_indices_q):
                raise ValueError(f"QEC cycle indices input for logical qubit {q} contains non-integer entries, list of integers expected.")

        # "Transpose" dictionary: map each QEC cycle index to the logical qubits requesting a cycle there
        qec_cycle_indices_T = {}
        for key, value_list in qec_cycle_indices.items():
            for value in value_list:
                qec_cycle_indices_T.setdefault(value, []).append(key)

        # If requested, remove QEC cycles (only full-clear supported)
        if clear_existing_qec:
            self.clear_qec_cycles()

        # Deepcopy current circuit data
        _data = copy.deepcopy(self.data)

        # Reconstruct circuit, appending QEC cycles along the way when needed
        self.data = []
        for i, instruction in enumerate(_data):
            # If current index is a key in transposed dictionary, append QEC cycles for corresponding values (logical qubit indices)
            if i in qec_cycle_indices_T:
                self.append_qec_cycle(qec_cycle_indices_T[i])

            # Append original circuit data
            self.data.append(instruction)

        return _data, self.data

    def append_qed_cycle(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        perform_flagged_syndrome_measurements: bool = True
    ) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
        """Append a QED cycle to the end of the circuit.

        For each logical qubit, a boxed cycle is appended which resets the ancillas, measures
        syndrome differences, and updates the previous-syndrome register. Unlike a QEC cycle,
        no decoding is applied. The first and last data index of each appended cycle are recorded.

        Args:
            logical_qubit_indices: Logical qubits for which to add QED cycles.
                Defaults to all logical qubits if `None` or empty.
            perform_flagged_syndrome_measurements: If True, perform the flagged syndrome
                measurements first and the unflagged ones conditionally; if False, perform only
                the unflagged syndrome measurements.

        Returns:
            The per-qubit records of initial and final data indices of all QED cycles
            (`qed_cycle_indices_initial`, `qed_cycle_indices_final`).
        """

        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        for q in logical_qubit_indices:
            # Compare the attribute itself with None: `len(...)` is never None and would raise on None
            if self.data_without_qed is None:
                self.data_without_qed = copy.deepcopy(self.data)
            else:
                if len(self.qed_cycle_indices_final[q]) > 0:
                    last_qed_index_final = self.qed_cycle_indices_final[q][-1]
                else:
                    last_qed_index_final = -1

                # Record everything appended since the last QED cycle on this qubit
                self.data_without_qed.extend(self.data[last_qed_index_final+1:])

            # Keep track of the initial index for now, only append at the end once we know this call was successful
            index_initial = len(self.data)

            with self.box(label="logical.qed.qed_cycle:$\\hat U_{QED}$"):
                super().reset(self.ancilla_qregs[q])

                if perform_flagged_syndrome_measurements:
                    # Perform first flagged syndrome measurements
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.flagged_stabilizers_1, flagged=True, steane_flag_1=True)

                    # If no change in syndrome, perform second flagged syndrome measurement
                    with self.if_test(self.cbit_and(self.flagged_syndrome_diff_cregs[q], [0]*self.flagged_syndrome_diff_cregs[q].size)) as _else:
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.flagged_stabilizers_2, flagged=True, steane_flag_2=True)
                    with _else:
                        pass

                    with self.if_test(self.cbit_and(self.flagged_syndrome_diff_cregs[q], [0]*self.flagged_syndrome_diff_cregs[q].size)) as _else:
                        pass
                    with _else:
                        # If change in syndrome, perform unflagged syndrome measurement
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, flagged=False)
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, flagged=False)

                        # Update previous syndrome: flip every bit whose unflagged difference is set
                        for n in range(self.n_stabilizers):
                            with self.if_test(expr.lift(self.unflagged_syndrome_diff_cregs[q][n])) as _else_inner:
                                self.cbit_not(self.prev_syndrome_cregs[q][n])
                            with _else_inner:
                                pass
                else:
                    # Perform unflagged syndrome measurements (no decoding or correction in a QED cycle)
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, flagged=False)
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, flagged=False)

                    # Update previous syndrome: flip every bit whose unflagged difference is set
                    for n in range(self.n_stabilizers):
                        with self.if_test(expr.lift(self.unflagged_syndrome_diff_cregs[q][n])) as _else_inner:
                            self.cbit_not(self.prev_syndrome_cregs[q][n])
                        with _else_inner:
                            pass

            index_final = len(self.data)-1

            self.qed_cycle_indices_initial[q].append(index_initial)
            self.qed_cycle_indices_final[q].append(index_final)

        return self.qed_cycle_indices_initial, self.qed_cycle_indices_final
    
    def append_qec_cycle(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        perform_flagged_syndrome_measurements: bool = True
    ) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
        """Append a QEC cycle to the end of the circuit.

        For each logical qubit, a boxed cycle is appended which resets the ancillas, measures
        syndrome differences, applies decoding, and updates the previous-syndrome register.
        The first and last data index of each appended cycle are recorded.

        Args:
            logical_qubit_indices: Logical qubits for which to add QEC cycles.
                Defaults to all logical qubits if `None` or empty.
            perform_flagged_syndrome_measurements: If True, perform the flagged syndrome
                measurements first and the unflagged measurements and decoding conditionally;
                if False, perform only the unflagged measurements and unflagged decoding.

        Returns:
            The per-qubit records of initial and final data indices of all QEC cycles
            (`qec_cycle_indices_initial`, `qec_cycle_indices_final`).
        """

        if logical_qubit_indices is None or len(logical_qubit_indices) == 0:
            logical_qubit_indices = list(range(self.n_logical_qubits))

        for q in logical_qubit_indices:
            # Compare the attribute itself with None: `len(...)` is never None and would raise on None
            if self.data_without_qec is None:
                self.data_without_qec = copy.deepcopy(self.data)
            else:
                if len(self.qec_cycle_indices_final[q]) > 0:
                    last_qec_index_final = self.qec_cycle_indices_final[q][-1]
                else:
                    last_qec_index_final = -1

                # Record everything appended since the last QEC cycle on this qubit
                self.data_without_qec.extend(self.data[last_qec_index_final+1:])

            # Keep track of the initial index for now, only append at the end once we know this call was successful
            index_initial = len(self.data)

            with self.box(label="logical.qec.qec_cycle:$\\hat U_{QEC}$"):
                super().reset(self.ancilla_qregs[q])

                if perform_flagged_syndrome_measurements:
                    # Perform first flagged syndrome measurements
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.flagged_stabilizers_1, flagged=True, steane_flag_1=True)

                    # If no change in syndrome, perform second flagged syndrome measurement
                    with self.if_test(self.cbit_and(self.flagged_syndrome_diff_cregs[q], [0]*self.flagged_syndrome_diff_cregs[q].size)) as _else:
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.flagged_stabilizers_2, flagged=True, steane_flag_2=True)
                    with _else:
                        pass

                    with self.if_test(self.cbit_and(self.flagged_syndrome_diff_cregs[q], [0]*self.flagged_syndrome_diff_cregs[q].size)) as _else:
                        pass
                    with _else:
                        # If change in syndrome, perform unflagged syndrome measurement, decode, and correct
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, flagged=False)
                        self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, flagged=False)

                        self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, with_flagged=False)
                        self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, with_flagged=False)

                        self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, with_flagged=True)
                        self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, with_flagged=True)

                        # Update previous syndrome: flip every bit whose unflagged difference is set
                        for n in range(self.n_stabilizers):
                            with self.if_test(expr.lift(self.unflagged_syndrome_diff_cregs[q][n])) as _else_inner:
                                self.cbit_not(self.prev_syndrome_cregs[q][n])
                            with _else_inner:
                                pass
                else:
                    # Perform unflagged syndrome measurements, decode, and correct
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, flagged=False)
                    self.measure_syndrome_diff(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, flagged=False)

                    self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.x_stabilizers, with_flagged=False)
                    self.apply_decoding(logical_qubit_indices=[q], stabilizer_indices=self.z_stabilizers, with_flagged=False)

                    # Update previous syndrome: flip every bit whose unflagged difference is set
                    for n in range(self.n_stabilizers):
                        with self.if_test(expr.lift(self.unflagged_syndrome_diff_cregs[q][n])) as _else_inner:
                            self.cbit_not(self.prev_syndrome_cregs[q][n])
                        with _else_inner:
                            pass

            index_final = len(self.data)-1

            self.qec_cycle_indices_initial[q].append(index_initial)
            self.qec_cycle_indices_final[q].append(index_final)

        return self.qec_cycle_indices_initial, self.qec_cycle_indices_final

    def clear_qec_cycles(
        self,
        logical_qubit_indices: Iterable[int] | None = None,
        qec_cycle_indices: Iterable[int ] | None = None
    ) -> NotImplementedError:
        """Placeholder for clearing QEC cycles (either specified or all) on specified logical qubits.

        Neither argument is used yet and the circuit is left untouched.

        Args:
            logical_qubit_indices: Logical qubits from which to clear QEC cycles. Intended to clear all if :py:type:`None`.
            qec_cycle_indices: QEC cycles to clear. Intended to clear all if :py:type:`None`

        Returns:
            A :class:`NotImplementedError` instance stating that this method is not yet
            implemented. Note that the exception is returned, not raised.
        """
        not_implemented = NotImplementedError("clear_qec_cycles is not yet implemented")
        return not_implemented

    # @TODO - determine appropriate syndrome decoding mappings dynamically
    def apply_decoding(
        self,
        logical_qubit_indices: Iterable[int],
        stabilizer_indices: Iterable[int],
        with_flagged: bool
    ):
        """Append classically-conditioned Pauli-frame updates decoded from syndrome differences.

        For every logical qubit, three conditional blocks are appended; each flips one bit of
        the qubit's Pauli-frame register when its condition on the syndrome-difference bits holds.
        The bit patterns are hard-coded with three entries each.

        Args:
            logical_qubit_indices: Logical qubits to apply decoding to.
            stabilizer_indices: Stabilizers for which to decode syndromes. The first entry selects
                the Pauli-frame bit: index 0 if that stabilizer contains 'X', otherwise index 1.
            with_flagged: Whether to condition on flagged together with unflagged syndrome
                differences (True) or on unflagged syndrome differences only (False).
        """
        # (flag pattern, syndrome pattern) pairs tested, in order, when decoding with flags
        flagged_cases = (
            ((1, 0, 0), (0, 1, 0)),
            ((1, 0, 0), (0, 0, 1)),
            ((0, 1, 1), (0, 0, 1)),
        )
        # Syndrome patterns tested, in order, when decoding without flags
        unflagged_cases = ((0, 1, 0), (0, 1, 1), (0, 0, 1))

        for q in logical_qubit_indices:
            unflagged_bits = self.unflagged_syndrome_diff_cregs[q]
            syn_diff = [unflagged_bits[x] for x in stabilizer_indices]

            # Determines index of pauli frame to be modified
            first_stabilizer = self.stabilizer_tableau[stabilizer_indices[0]]
            if 'X' in first_stabilizer:
                pf_ind = 0
            else:
                pf_ind = 1

            if not with_flagged:
                # Unflagged decoding sequence
                for syn_pattern in unflagged_cases:
                    condition = self.cbit_and(syn_diff, list(syn_pattern))
                    with super().if_test(condition) as _else:
                        self.cbit_not(self.pauli_frame_cregs[q][pf_ind])
                    with _else:
                        pass
                continue

            # Decoding sequence with flagged syndrome
            flagged_bits = self.flagged_syndrome_diff_cregs[q]
            flag_diff = [flagged_bits[x] for x in stabilizer_indices]
            for flag_pattern, syn_pattern in flagged_cases:
                condition = expr.bit_and(
                    self.cbit_and(flag_diff, list(flag_pattern)),
                    self.cbit_and(syn_diff, list(syn_pattern)),
                )
                with super().if_test(condition) as _else:
                    self.cbit_not(self.pauli_frame_cregs[q][pf_ind])
                with _else:
                    pass

    def measure(
        self,
        logical_qubit_indices: Iterable[int],
        cbit_indices: Iterable[int],
        with_error_correction: bool = True
    ):
        """Measure specified logical qubits (in the Z basis) into classical output bits.

        Each logical qubit is measured inside its own box: all physical qubits are measured,
        the output bit is set from the parity of physical bits 4, 5 and 6, and, if requested,
        a final syndrome is computed from the measured bits and used to correct the output bit.

        Args:
            logical_qubit_indices: Logical qubits to measure.
            cbit_indices: Classical bits (of the output register) in which to record measurements.
            with_error_correction: Whether to apply QEC to the measurement result.

        Raises:
            :class:`ValueError`: if `logical_qubit_indices` or `cbit_indices` is not an iterable or
                if number of qubits and cbits does not match.
        """
        # Validate both index collections before touching the circuit
        if not hasattr(logical_qubit_indices, "__iter__"):
            raise ValueError("Logical qubit indices must be an iterable!")
        if not hasattr(cbit_indices, "__iter__"):
            raise ValueError("Classical bit indices must be an iterable!")
        if len(logical_qubit_indices) != len(cbit_indices):
            raise ValueError("Number of qubits should equal number of classical bits")

        for q, c in zip(logical_qubit_indices, cbit_indices):
            with self.box(label="logical.qec.measure:$\\hat{M}_\\text{QEC}$"):
                # Measurement of state: every physical qubit goes to its final-measurement bit
                for phys in range(self.n_physical_qubits):
                    super().append(Measure(), [self.logical_qregs[q][phys]], [self.final_measurement_cregs[q][phys]], copy=False)

                # @TODO - use LogicalZVector instead
                logical_z_bits = [self.final_measurement_cregs[q][x] for x in [4,5,6]]
                with super().if_test(self.cbit_xor(logical_z_bits)) as _else:
                    self.set_cbit(self.output_creg[c], 1)
                with _else:
                    pass

                if not with_error_correction:
                    continue

                # Final syndrome: parity of the measured bits on the support of each Z stabilizer
                for n in range(self.n_ancilla_qubits):
                    stabilizer = self.stabilizer_tableau[self.z_stabilizers[n]]
                    s_indices = [position for position, pauli in enumerate(stabilizer) if pauli == 'Z']

                    with super().if_test(self.cbit_xor([self.final_measurement_cregs[q][z] for z in s_indices])) as _else:
                        self.set_cbit(self.curr_syndrome_cregs[q][n], 1)
                    with _else:
                        pass

                # Final syndrome diff
                for n in range(self.n_ancilla_qubits):
                    z_index = self.z_stabilizers[n]
                    with super().if_test(self.cbit_xor([self.curr_syndrome_cregs[q][n], self.prev_syndrome_cregs[q][z_index]])) as _else:
                        self.set_cbit(self.unflagged_syndrome_diff_cregs[q][self.z_stabilizers[n]], 1)
                    with _else:
                        self.set_cbit(self.unflagged_syndrome_diff_cregs[q][self.z_stabilizers[n]], 0)

                # Final correction: decode, then flip the output bit if Pauli-frame bit 1 is set
                self.apply_decoding([q], self.z_stabilizers, with_flagged=False)
                with super().if_test(expr.lift(self.pauli_frame_cregs[q][1])) as _else:
                    self.cbit_not(self.output_creg[c])
                with _else:
                    pass

    def measure_all(
        self,
        inplace: bool = True,
        with_error_correction: bool = True
    ) -> LogicalCircuit:
        """Measure every logical qubit into the output bit of the same index.

        Args:
            inplace: Whether to add the measurements to this circuit (True) or to a deep copy (False).
            with_error_correction: Whether to perform measurements with QEC.

        Returns:
            The measured deep copy if `inplace = False`; nothing if `inplace = True`.
        """
        if not inplace:
            # Work on a deep copy so this circuit stays unmeasured
            measured_copy = copy.deepcopy(self)
            measured_copy.measure_all(inplace=True, with_error_correction=with_error_correction)
            return measured_copy

        qubit_indices = range(self.n_logical_qubits)
        cbit_indices = range(self.n_logical_qubits)
        self.measure(qubit_indices, cbit_indices, with_error_correction=with_error_correction)

    def remove_final_measurements(
        self,
        inplace: bool = False
    ) -> LogicalCircuit:
        """Build a copy of this circuit without its top-level measure instructions.

        A new :class:`LogicalCircuit` is created with the same code parameters and the name
        suffixed by ``_no_meas``; every instruction of this circuit whose name is not
        ``measure`` is appended to it.

        Args:
            inplace: Whether to remove measurements in place, i.e., for this circuit (not supported).

        Raises:
            `NotImplementedError`: if in-place removal is attempted.

        Returns:
            The circuit without measurements.
        """
        if inplace:
            raise NotImplementedError("Inplace measurement removal is not supported")

        lqc_no_meas = LogicalCircuit(
            self.n_logical_qubits,
            self.label,
            self.stabilizer_tableau,
            self.name + "_no_meas",
        )

        for circuit_instruction in self.data:
            # Skip measurements; everything else is carried over unchanged
            if circuit_instruction.name == "measure":
                continue
            lqc_no_meas._append(circuit_instruction)

        return lqc_no_meas

    def get_logical_counts(
        self,
        physical_counts: dict[str, int],
        logical_qubit_indices: Iterable[int] | None = None
    ) -> dict[str, int]:
        """Get logical counts from physical counts.

        For each physical outcome string, the character at position
        ``n_logical_qubits - 1 - l`` is taken for every requested logical qubit ``l`` (in the
        order given) and the characters are joined into the logical outcome. Counts of physical
        outcomes that map to the same logical outcome are summed.

        Args:
            physical_counts: Mapping from physical outcome strings to their counts.
            logical_qubit_indices: Logical qubits to get counts for. If `None`, then get counts for all.

        Returns:
            Logical qubit counts.
        """
        if logical_qubit_indices is None:
            logical_qubit_indices = range(self.n_logical_qubits)

        # Resolve the string positions once, up front: this also makes a one-shot iterator
        # safe to pass, since it would otherwise be exhausted after the first outcome
        positions = [self.n_logical_qubits - 1 - index for index in logical_qubit_indices]

        logical_counts = {}
        for physical_outcome, physical_outcome_counts in physical_counts.items():
            logical_outcome = "".join(physical_outcome[position] for position in positions)

            logical_counts[logical_outcome] = logical_counts.get(logical_outcome, 0) + physical_outcome_counts

        return logical_counts

    ######################################
    ##### Logical quantum operations #####
    ######################################

    def h(self, *targets, method="Coherent_Feedback"):
        """
        Logical Hadamard gate
        """

        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        if method == "LCU":
            for t in targets:
                with self.box(label="logical.logicalop.h.lcu:$\\hat H_{L}$"):
                    super().compose(self.LogicalHCircuit_LCU, [self.logical_op_qregs[t][0]] + self.logical_qregs[t][:], inplace=True)

            # @TODO - perform resets after main operation is complete to allow for faster(?) parallel operation
            # for t in targets:
                # @TODO - determine whether extra reset is necessary at the end
                # with self.box(label="logical.logicalop.lcu"):
                    # super().reset(self.logical_op_qregs[t])
        elif method == "LCU_Corrected": 
            for t in targets:
                with self.box(label="logical.logicalop.h.lcu_corrected:$\\hat H_{L}$"):
                    # Construct circuit for implementing a Hadamard gate through the use of an ancilla
                    super().h(self.logical_op_qregs[t][0])
                    super().compose(self.LogicalXCircuit.control(1), [self.logical_op_qregs[t][0]] + self.logical_qregs[t][:], inplace=True)
                    super().compose(self.LogicalZCircuit.control(1), [self.logical_op_qregs[t][0]] + self.logical_qregs[t][:], inplace=True)
                    super().h(self.logical_op_qregs[t][0])
                    super().append(Measure(), [self.logical_op_qregs[t][0]], [self.logical_op_meas_cregs[t][0]], copy=False)
                    super().reset(self.logical_op_qregs[t][0])

                    # Corrections to apply based on ancilla measurement
                    with super().if_test((self.logical_op_meas_cregs[t][0], 1)) as else_:
                        self.x(t)
                    with else_:
                        self.z(t)

        elif method == "Coherent_Feedback":
            for t in targets:
                with self.box(label="logical.logicalop.h.coherent_feedback:$\\hat H_{L}$"):
                    super().compose(self.LogicalHCircuit_CF, self.logical_qregs[t][:] + [self.logical_op_qregs[t][0]], inplace=True)

        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.h.transversal_uniform:$\\hat H_{L}$"):
                    super().h(self.logical_qregs[t][:])
        else:
            raise ValueError(f"'{method}' is not a valid method for the logical Hadamard gate")

    def x(self, *targets):
        """
        Logical PauliX gate

        Composes `self.LogicalXCircuit` onto the register of each target, one box per target.

        Args:
            targets: Logical qubit indices, passed either as separate arguments or as a single iterable.
        """

        # Accept both x(0, 1) and x([0, 1])
        single_iterable = len(targets) == 1 and hasattr(targets[0], "__iter__")
        logical_indices = targets[0] if single_iterable else targets

        for idx in logical_indices:
            with self.box(label="logical.logicalop.x.gottesman:$\\hat X_{L}$"):
                super().compose(self.LogicalXCircuit, self.logical_qregs[idx], inplace=True)

    def y(self, *targets):
        """
        Logical PauliY gate

        Implemented inside a single box as a logical Z followed by a logical X on the same targets.

        Args:
            targets: Logical qubit indices, passed either as separate arguments or as a single iterable
                (one-shot iterators such as generators are supported).
        """

        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            # Materialize the iterable: it is handed to both z() and x(), and a one-shot
            # iterator would be exhausted by z(), leaving x() with no targets
            targets = tuple(targets[0])

        with self.box(label="logical.logicalop.y.derived:$\\hat Y_{L}$"):
            self.z(targets)
            self.x(targets)

    def z(self, *targets):
        """
        Logical PauliZ gate

        Composes `self.LogicalZCircuit` onto the register of each target, one box per target.

        Args:
            targets: Logical qubit indices, passed either as separate arguments or as a single iterable.
        """

        # Accept both z(0, 1) and z([0, 1])
        single_iterable = len(targets) == 1 and hasattr(targets[0], "__iter__")
        logical_indices = targets[0] if single_iterable else targets

        for idx in logical_indices:
            with self.box(label="logical.logicalop.z.gottesman:$\\hat Z_{L}$"):
                super().compose(self.LogicalZCircuit, self.logical_qregs[idx], inplace=True)

    def s(self, *targets, method="Coherent_Feedback"):
        """
        Logical S gate

        Definition:
        [1   0]
        [0   i]
        """

        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        if method == "LCU_Corrected":
            for t in targets:
                with self.box(label="logical.logicalop.s.lcu_corrected:$\\hat S_{L}$"):
                    super().h(self.logical_op_qregs[t][0])
                    super().s(self.logical_op_qregs[t][0])
                    super().h(self.logical_op_qregs[t][0])
                    super().compose(self.LogicalZCircuit.control(1), [self.logical_op_qregs[t][0]] + self.logical_qregs[t][:], inplace=True)
                    super().h(self.logical_op_qregs[t][0])
                    super().append(Measure(), [self.logical_op_qregs[t][0]], [self.logical_op_meas_cregs[t][0]], copy=False)

                    with super().if_test((self.logical_op_meas_cregs[t][0], 1)) as _else:
                        self.z(t)
                    with _else:
                        pass

                    super().reset(self.logical_op_qregs[t][0])
        elif method == "Coherent_Feedback":
            for t in targets:
                with self.box(label="logical.logicalop.s.coherent_feedback:$\\hat S_{L}$"):
                    super().compose(self.LogicalSCircuit_CF, self.logical_qregs[t][:] + [self.logical_op_qregs[t][0]], inplace=True)
        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.s.transversal_uniform:$\\hat S_{L}$"):
                    super().sdg(self.logical_qregs[t][:])

        else:
            raise ValueError(f"'{method}' is not a valid method for the logical S gate")

    def sdg(self, *targets, method="Coherent_Feedback"):
        """
        Logical S^dagger gate

        Definition:
        [1    0]
        [0   -i]
        """
        
        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        if method == "LCU_Corrected":
            for t in targets:
                with self.box(label="logical.logicalop.sdg.lcu_corrected:$\\hat{S^\\dagger}_{L}$"):
                    super().h(self.logical_op_qregs[t][0])
                    super().sdg(self.logical_op_qregs[t][0])
                    super().h(self.logical_op_qregs[t][0])
                    super().compose(self.LogicalZCircuit.control(1), [self.logical_op_qregs[t][0]] + self.logical_qregs[t][:], inplace=True)
                    super().h(self.logical_op_qregs[t][0])
                    super().append(Measure(), [self.logical_op_qregs[t][0]], [self.logical_op_meas_cregs[t][0]], copy=False)

                    with super().if_test((self.logical_op_meas_cregs[t][0], 1)) as _else:
                        self.z(t)
                    with _else:
                        pass

                    super().reset(self.logical_op_qregs[t][0])
        elif method == "Coherent_Feedback":
            for t in targets:
                with self.box(label="logical.logicalop.sdg.coherent_feedback:$\\hat{S^\\dagger}_{L}$"):
                    super().compose(self.LogicalSdgCircuit_CF, self.logical_qregs[t][:] + [self.logical_op_qregs[t][0]], inplace=True)

        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.sdg.transversal_uniform:$\\hat{S^\\dagger}_{L}$"):
                    super().s(self.logical_qregs[t][:])

        else:
            raise ValueError(f"'{method}' is not a valid method for the logical S^dagger gate")

    def t(self, *targets, method="Coherent_Feedback"):
        """
        Logical T gate

        Definition:
        [1    0        ]
        [0    e^(ipi/4)]

        Args:
            targets: Logical qubit indices, passed either as separate arguments or as a single iterable.
            method: Either "LCU_Corrected" or "Coherent_Feedback".

        Raises:
            `ValueError`: if `method` is none of the names above.
        """

        # Accept both t(0, 1) and t([0, 1])
        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        if method == "LCU_Corrected":
            for t in targets:
                with self.box(label="logical.logicalop.t.lcu_corrected:$\\hat T_{L}$"):
                    ancilla = self.logical_op_qregs[t][0]
                    super().h(ancilla)
                    super().t(ancilla)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().h(ancilla)

                    super().append(Measure(), [ancilla], [self.logical_op_meas_cregs[t][0]], copy=False)
                    super().reset(ancilla)

                    # Logical S correction when the ancilla was measured as 1.
                    # The method name must match the spelling accepted by s() exactly ("LCU_Corrected");
                    # any other casing makes s() raise ValueError.
                    with super().if_test((self.logical_op_meas_cregs[t][0], 1)) as _else:
                        self.s(t, method="LCU_Corrected")
                    with _else:
                        pass

        elif method == "Coherent_Feedback":
            for t in targets:
                with self.box(label="logical.logicalop.t.coherent_feedback:$\\hat T_{L}$"):
                    # Data qubits first, then the whole logical-op ancilla register
                    super().compose(self.LogicalTCircuit_CF, self.logical_qregs[t][:] + self.logical_op_qregs[t][:], inplace=True)

        else:
            raise ValueError(f"'{method}' is not a valid method for the logical T gate")

    def tdg(self, *targets, method="Coherent_Feedback"):
        """
        Logical T^dagger gate

        Definition:
        [1    0         ]
        [0    e^(-ipi/4)]

        Args:
            targets: Logical qubit indices, passed either as separate arguments or as a single iterable.
            method: Either "LCU_Corrected" or "Coherent_Feedback".

        Raises:
            `ValueError`: if `method` is none of the names above.
        """

        # Accept both tdg(0, 1) and tdg([0, 1])
        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        # NOTE(review): both box labels below reuse the "t" operation name rather than "tdg";
        # left unchanged because other code may parse these labels - confirm before renaming
        if method == "LCU_Corrected":
            for t in targets:
                with self.box(label="logical.logicalop.t.lcu_corrected:$\\hat{T^\\dagger}_{L}$"):
                    ancilla = self.logical_op_qregs[t][0]
                    super().h(ancilla)
                    super().tdg(ancilla)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().h(ancilla)

                    super().append(Measure(), [ancilla], [self.logical_op_meas_cregs[t][0]], copy=False)
                    super().reset(ancilla)

                    # Logical S^dagger correction when the ancilla was measured as 1.
                    # The method name must match the spelling accepted by sdg() exactly ("LCU_Corrected");
                    # any other casing makes sdg() raise ValueError.
                    with super().if_test((self.logical_op_meas_cregs[t][0], 1)) as _else:
                        self.sdg(t, method="LCU_Corrected")
                    with _else:
                        pass

        elif method == "Coherent_Feedback":
            for t in targets:
                with self.box(label="logical.logicalop.t.coherent_feedback:$\\hat{T^\\dagger}_{L}$"):
                    # Data qubits first, then the whole logical-op ancilla register
                    super().compose(self.LogicalTdgCircuit_CF, self.logical_qregs[t][:] + self.logical_op_qregs[t][:], inplace=True)

        else:
            raise ValueError(f"'{method}' is not a valid method for the logical T^dagger gate")

    def cx(self, control, *_targets, method="Ancilla_Assisted"):
        """
        Logical Controlled-PauliX gate

        Args:
            control: Index of the logical control qubit.
            _targets: Logical target indices, passed either as separate arguments or as a single iterable.
                One gate is emitted per target, each sharing the same control.
            method: Either "Ancilla_Assisted" or "Transversal_Uniform".

        Raises:
            `ValueError`: if `method` is none of the names above.
        """

        # `_targets` is always a tuple (it comes from *args), so unwrap a single iterable argument
        # the same way the single-qubit gates do; this makes cx(c, [t1, t2]) equivalent to cx(c, t1, t2)
        if len(_targets) == 1 and hasattr(_targets[0], "__iter__"):
            targets = _targets[0]
        else:
            targets = _targets

        # @TODO - implement a better, more generalized CNOT gate
        if method == "Ancilla_Assisted":
            for t in targets:
                with self.box(label="logical.logicalop.cx.ancilla_assisted:$\\hat{CX}_{L}$"):
                    # The ancilla of the *target* logical qubit mediates the interaction
                    ancilla = self.logical_op_qregs[t][0]
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)
                    super().compose(self.LogicalXCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)
        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.cx.transversal_uniform:$\\hat{CX}_{L}$"):
                    # Physical CX between corresponding qubits of the control and target registers
                    super().cx(self.logical_qregs[control][:], self.logical_qregs[t][:])
        else:
            raise ValueError(f"'{method}' is not a valid method for the logical CX gate")

    def cz(self, control, *_targets, method="Ancilla_Assisted"):
        """
        Logical Controlled-PauliZ gate

        Args:
            control: Index of the logical control qubit.
            _targets: Logical target indices, passed either as separate arguments or as a single iterable.
                One gate is emitted per target, each sharing the same control.
            method: Either "Ancilla_Assisted" or "Transversal_Uniform".

        Raises:
            `ValueError`: if `method` is none of the names above.
        """

        # `_targets` is always a tuple (it comes from *args), so unwrap a single iterable argument
        # the same way the single-qubit gates do; this makes cz(c, [t1, t2]) equivalent to cz(c, t1, t2)
        if len(_targets) == 1 and hasattr(_targets[0], "__iter__"):
            targets = _targets[0]
        else:
            targets = _targets

        # NOTE(review): the box labels below use the "cx" operation name for a CZ gate;
        # left unchanged because other code may parse these labels - confirm before renaming
        # @TODO - implement a better, more generalized CZ gate
        if method == "Ancilla_Assisted":
            for t in targets:
                with self.box(label="logical.logicalop.cx.ancilla_assisted:$\\hat{CZ}_{L}$"):
                    # The ancilla of the *target* logical qubit mediates the interaction
                    ancilla = self.logical_op_qregs[t][0]
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)
        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.cx.transversal_uniform:$\\hat{CZ}_{L}$"):
                    # Physical CZ between corresponding qubits of the control and target registers
                    super().cz(self.logical_qregs[control][:], self.logical_qregs[t][:])
        else:
            raise ValueError(f"'{method}' is not a valid method for the logical CZ gate")
        
    def cy(self, control, *_targets, method="Ancilla_Assisted"):
        """
        Logical Controlled-PauliY gate

        Args:
            control: Index of the logical control qubit.
            _targets: Logical target indices, passed either as separate arguments or as a single iterable.
                One gate is emitted per target, each sharing the same control.
            method: Either "Ancilla_Assisted" or "Transversal_Uniform".

        Raises:
            `ValueError`: if `method` is none of the names above.
        """

        # `_targets` is always a tuple (it comes from *args), so unwrap a single iterable argument
        # the same way the single-qubit gates do; this makes cy(c, [t1, t2]) equivalent to cy(c, t1, t2)
        if len(_targets) == 1 and hasattr(_targets[0], "__iter__"):
            targets = _targets[0]
        else:
            targets = _targets

        # NOTE(review): the box labels below use the "cx" operation name for a CY gate;
        # left unchanged because other code may parse these labels - confirm before renaming
        # @TODO - implement a better, more generalized CY gate
        if method == "Ancilla_Assisted":
            for t in targets:
                with self.box(label="logical.logicalop.cx.ancilla_assisted:$\\hat{CY}_{L}$"):
                    # The ancilla of the *target* logical qubit mediates the interaction
                    ancilla = self.logical_op_qregs[t][0]
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)
                    super().s(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().compose(self.LogicalXCircuit.control(1), [ancilla] + self.logical_qregs[t][:], inplace=True)
                    super().h(ancilla)
                    super().compose(self.LogicalZCircuit.control(1), [ancilla] + self.logical_qregs[control][:], inplace=True)
                    super().h(ancilla)

        elif method == "Transversal_Uniform":
            for t in targets:
                with self.box(label="logical.logicalop.cx.transversal_uniform:$\\hat{CY}_{L}$"):
                    # Physical CY (not CX) between corresponding qubits of the two registers, in line
                    # with how the transversal CX and CZ variants apply their own physical gate.
                    # NOTE(review): whether uniform physical CY realizes logical CY depends on the code - confirm
                    super().cy(self.logical_qregs[control][:], self.logical_qregs[t][:])
        else:
            raise ValueError(f"'{method}' is not a valid method for the logical CY gate")

    def mcmt(self, gate, controls, targets):
        """
        Logical Multi-Control Multi-Target gate

        Appends `gate.control(len(controls))` inside a single box, passing the register slice of each
        control followed by the register slice of each target as the qubit arguments.

        Args:
            gate: Gate to be controlled; must provide a `control(n)` method.
            controls: Logical control indices, or a one-element sequence wrapping an iterable of them.
            targets: Logical target indices, or a one-element sequence wrapping an iterable of them.

        Raises:
            `ValueError`: if any qubit is used both as a control and as a target.
        """

        if len(controls) == 1 and hasattr(controls[0], "__iter__"):
            controls = controls[0]

        if len(targets) == 1 and hasattr(targets[0], "__iter__"):
            targets = targets[0]

        control_qubits = [self.logical_qregs[c][:] for c in controls]
        target_qubits = [self.logical_qregs[t][:] for t in targets]

        # The entries above are lists (unhashable), so the overlap check is done on the
        # individual qubits rather than on the per-register slices
        control_members = {qubit for block in control_qubits for qubit in block}
        if any(qubit in control_members for block in target_qubits for qubit in block):
            raise ValueError("Qubit(s) specified as both control and target")

        with self.box(label="logical.logicalop.mcmt.default:$\\hat{MCMT}_{L}$"):
            super().append(gate.control(len(controls)), control_qubits + target_qubits)

    def rx(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical single-target rotation about the X axis

        Delegates to `self.r("x", ...)` with label "Rx".

        Args:
            theta: Rotation angle in radians.
            targets: A logical qubit index or an iterable of indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" (oblivious amplitude
                amplification) is recognized but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # Wrap a bare index so that self.r always receives an iterable
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if method == "S-K":
            self.r(
                "x",
                targets,
                theta,
                label="Rx",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box)
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")
            
    def ry(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical single-target rotation about the Y axis

        Delegates to `self.r("y", ...)` with label "Ry".

        Args:
            theta: Rotation angle in radians.
            targets: A logical qubit index or an iterable of indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" is recognized
                but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # Wrap a bare index so that self.r always receives an iterable
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if method == "S-K":
            self.r(
                "y",
                targets,
                theta,
                label="Ry",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box)
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")
            
    def rz(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical single-target rotation about the Z axis

        Delegates to `self.r("z", ...)` with label "Rz".

        Args:
            theta: Rotation angle in radians.
            targets: A logical qubit index or an iterable of indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" is recognized
                but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # Wrap a bare index so that self.r always receives an iterable
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if method == "S-K":
            self.r(
                "z",
                targets,
                theta,
                label="Rz",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box)
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")
            
    def rxx(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical two-target XX rotation

        Delegates to `self.r("xx", ...)` with label "Rxx".

        Args:
            theta: Rotation angle in radians.
            targets: Exactly two logical qubit indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" is recognized
                but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `AssertionError`: if the number of targets is not 2.
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # A bare index becomes a one-element list, which then fails the length check below
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if len(targets) != 2:
            raise AssertionError("Number of target qubits must be 2.")

        if method == "S-K":
            self.r(
                "xx",
                targets,
                theta,
                label="Rxx",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box)
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")
            
    def ryy(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical two-target YY rotation

        Delegates to `self.r("yy", ...)` with label "Ryy".

        Args:
            theta: Rotation angle in radians.
            targets: Exactly two logical qubit indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" is recognized
                but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `AssertionError`: if the number of targets is not 2.
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # A bare index becomes a one-element list, which then fails the length check below
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if len(targets) != 2:
            raise AssertionError("Number of target qubits must be 2.")

        if method == "S-K":
            self.r(
                "yy",
                targets,
                theta,
                label="Ryy",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box)
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")
            
    def rzz(self, theta: float, targets, method = "S-K", depth = 10, recursion_degree = 1, box=True):
        """
        Logical two-target ZZ rotation

        Delegates to `self.r("zz", ...)` with label "Rzz".

        Args:
            theta: Rotation angle in radians.
            targets: Exactly two logical qubit indices.
            method: "S-K" (Solovay-Kitaev) is the only implemented method; "OAA" is recognized
                but not implemented.
            depth: Forwarded to `self.r`.
            recursion_degree: Forwarded to `self.r`.
            box: Forwarded to `self.r`.

        Raises:
            `AssertionError`: if the number of targets is not 2.
            `NotImplementedError`: if `method` is "OAA".
            `ValueError`: if `method` is anything else other than "S-K".
        """
        # A bare index becomes a one-element list, which then fails the length check below
        if not hasattr(targets, "__iter__"):
            targets = [targets]

        if len(targets) != 2:
            raise AssertionError("Number of target qubits must be 2.")

        if method == "S-K":
            self.r(
                "zz",
                targets,
                theta,
                label="Rzz",
                depth = depth,
                recursion_degree = recursion_degree,
                box = box
                )
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError(f"{method} is not a valid method.")

    def append_sk_decomposition(self, circuit, targets, label="U", depth=10, recursion_degree=1, box=False, return_subcircuit=False):
        """Discretize `circuit` with the Solovay-Kitaev pass and append the result to this circuit.

        Args:
            circuit: Circuit to discretize.
            targets: Sequence indexed by the position of a qubit in the discretized circuit; the entry
                found there is used as the qubit argument when appending to this circuit.
            label: Suffix of the box label (`"logical.logicalop.<label>"`), used only when `box` is truthy.
            depth: Depth passed to `generate_basic_approximations`.
            recursion_degree: Recursion degree passed to `SolovayKitaev`.
            box: Whether to wrap all appended instructions in a single box.
            return_subcircuit: Whether to return the discretized circuit.

        Returns:
            The discretized circuit if `return_subcircuit` is truthy, otherwise `None`.
        """

        gate_set = ["s", "sdg", "t", "tdg", "h", "x", "y", "z", "cz"]
        basic_approximations = generate_basic_approximations(gate_set, depth=depth)
        synthesizer = SolovayKitaev(recursion_degree=recursion_degree, basic_approximations=basic_approximations)

        discretized_sub_qc = synthesizer(circuit)

        def emit_instructions():
            # Translate each qubit of the discretized circuit into the caller-supplied target
            sub_qubits = discretized_sub_qc.qubits
            for circuit_instruction in discretized_sub_qc.data:
                mapped = [targets[sub_qubits.index(qubit)] for qubit in circuit_instruction.qubits]
                self.append(circuit_instruction, qargs=mapped)

        if not box:
            emit_instructions()
        else:
            with self.box(label=f"logical.logicalop.{label}"):
                emit_instructions()

        return discretized_sub_qc if return_subcircuit else None

    def r(self, axis, targets, theta = 0, label = "R", depth = 10, recursion_degree = 1, box=True, method = "S-K"):
        if isinstance(axis, str):        
            # In form "instruction.name: (Gate, num_targets_per_gate)"
            valid_gates = {"x": (RXGate, 1), "y": (RYGate, 1), "z": (RZGate, 1), "xx": (RXXGate, 2), "yy": (RYYGate, 2), "zz": (RZZGate, 2)} 
            
            if axis not in list(valid_gates.keys()):
                raise NotImplementedError(f"Invalid input '{axis}' for argument 'axis'.")
            
            if label == "R":
                label = label + axis
                
        elif isinstance(axis, list):
            if len(axis) == 3:
                raise NotImplementedError("Arbitrary rotation axes are not yet implemented.")
            else:
                raise ValueError(f"'axis' is list of invalid length ({len(axis)}). 'axis' must have length 3.")
            
        else:
            raise TypeError(f"Provided 'axis' is not an instance of an allowed type (str, int).")
        
        if method == "S-K":
            gate_base, num_target_qubits = valid_gates[axis]
            gate = gate_base(theta)
            
            sub_qc = QuantumCircuit(num_target_qubits)
        
            def apply_Rzz(sub_qc):
                sub_qc.cx(0, 1)
                sub_qc.rz(theta, 1)
                sub_qc.cx(0, 1)
        
            match axis:
                case "xx":
                    sub_qc.h([0, 1])
                    apply_Rzz(sub_qc)
                    sub_qc.h([0, 1])
                case "yy":
                    sub_qc.rx(np.pi / 2, [0, 1])
                    apply_Rzz(sub_qc)
                    sub_qc.rx(-np.pi / 2, [0, 1])
                case "zz":
                    apply_Rzz(sub_qc)
                case _:        
                    sub_qc.append(gate, qargs = list(range(num_target_qubits)))
            
            self.append_sk_decomposition(sub_qc, targets, label=label, depth=depth, recursion_degree=recursion_degree, box=box)
            
        elif method == "OAA":
            raise NotImplementedError("Method not implemented.")
        else:
            raise ValueError("{method} is not a valid method.")

    # Input could be: 1. (CircuitInstruction(name="...", qargs="...", cargs="..."), qargs=None, cargs=None)
    #                 2. (Instruction(name="..."), qargs=[..], cargs=[...])
    def append(self, instruction, qargs=None, cargs=None, copy=True):
        if isinstance(instruction, str):
            operation = instruction
        elif hasattr(instruction, "name"):
            operation = instruction.name.lower()
        else:
            raise ValueError(f"Instruction could not be parsed: {instruction}")

        # @TODO - copy.deepcopy fails for some instructions (such as IfElseOp), figure out a fix
        if copy and not isinstance(instruction, bool):
            try:
                instruction = copy.deepcopy(instruction)
            except:
                pass
                # print(f"WARNING: LogicalCircuit does not support append-by-copy for instruction '{operation}', ignoring")

        if qargs is None:
            if hasattr(instruction, "qubits"):
                qargs = instruction.qubits
                qubits = [qubit._index for qubit in instruction.qubits]
        else:
            if all([isinstance(qarg, int) for qarg in qargs]):
                qubits = qargs
            elif all([isinstance(qarg, Bit) for qarg in qargs]):
                qubits = [qarg._index for qarg in qargs]
            elif hasattr(instruction, "qubits"):
                qubits = [qubit._index for qubit in instruction.qubits]
            elif all([isinstance(qarg, QuantumRegister) for qarg in qargs]):
                qubits = [qubit._index for qarg in qargs for qubit in qarg]
            else:
                raise ValueError(f"At least one of the following quantum arguments to operation '{operation}' are unrecognized: {qargs}")

        if cargs is None:
            if hasattr(instruction, "clbits"):
                cargs = instruction.clbits
                clbits = [clbit._index for clbit in instruction.clbits]
        else:
            if all([isinstance(carg, int) for carg in cargs]):
                clbits = cargs
            elif all([isinstance(carg, Bit) for carg in cargs]):
                qubits = [carg._index for carg in cargs]
            elif hasattr(instruction, "clbits"):
                clbits = [clbit._index for clbit in instruction.clbits]
            elif all([isinstance(carg, ClassicalRegister) for carg in cargs]):
                clbits = [clbit._index for carg in cargs for clbit in carg]
            else:
                raise ValueError(f"At least one of the following classical arguments to operation '{operation}' are unrecognized: {cargs}")

        match operation:
            case "h":
                self.h(qubits)
            case "x":
                self.x(qubits)
            case "y":
                self.y(qubits)
            case "z":
                self.z(qubits)
            case "s":
                self.s(qubits)
            case "sdg":
                self.sdg(qubits)
            case "t":
                self.t(qubits)
            case "tdg":
                self.tdg(qubits)
            case "cx":
                control_qubit = instruction.qubits[0]._index
                target_qubit = instruction.qubits[1]._index
                self.cx(control_qubit, target_qubit)
            case "cz":
                control_qubit = instruction.qubits[0]._index
                target_qubit = instruction.qubits[1]._index
                self.cz(control_qubit, target_qubit)
            case "cy":
                control_qubit = instruction.qubits[0]._index
                target_qubit = instruction.qubits[1]._index
                self.cy(control_qubit, target_qubit)
            case "rx":
                theta = instruction.params[0]
                self.rx(theta, qubits)
            case "ry":
                theta = instruction.params[0]
                self.ry(theta, qubits)
            case "rz":
                theta = instruction.params[0]
                self.rz(theta, qubits)
            case "Rxx":
                theta = instruction.params[0]
                self.rxx(theta, qubits)
            case "Ryy":
                theta = instruction.params[0]
                self.ryy(theta, qubits)    
            case "Rzz":
                theta = instruction.params[0]
                self.rzz(theta, qubits)
            # @TODO Fix code to initialize LogicalCircuit to arbitrary logical state.
            #case "initialize":
            #    sv = instruction.params
            #    #if isinstance(sv, list):
            #    #    sv = Statevector(sv)
            #        
            #    lsv = LogicalStatevector(sv, len(qubits), self.label, self.stabilizer_tableau)
            #    self.initialize(lsv.data)
            case "mcmt":
                raise NotImplementedError(f"Physical operation 'MCMT' does not have physical gate conversion implemented!")
            case "measure":
                # If classical bits for measurement aren't specified, default to match logical qubit indices
                if clbits is None:
                    clbits = qubits

                # @TODO - decide best default behavior here (maybe we should ask during from_physical_circuit)
                self.measure(qubits, clbits, with_error_correction=True)
            case "barrier":
                pass
            case _:
                # @TODO - identify a better way of providing these warnings
                # print(f"WARNING: Physical operation '{operation.upper()}' does not have a logical counterpart implemented! Defaulting to physical operation.")

                instruction = super().append(instruction, qargs, cargs, copy=copy)

        return instruction

    ###########################
    ##### Utility methods #####
    ###########################

    # Adds a desired error for testing
    def add_error(self, l_ind, p_ind, error_type):
        if error_type == 'X':
            super().x(self.logical_qregs[l_ind][p_ind])
        if error_type == 'Z':
            super().z(self.logical_qregs[l_ind][p_ind])

    # @TODO - find alternative to classical methods, possibly by implementing upstream

    # Set values of classical bits
    def set_cbit(self, cbit, value):
        """Write a value into a classical bit by measuring a setter qubit into it.

        `cbit_setter_qreg[0]` is measured when `value == 0`, `cbit_setter_qreg[1]` otherwise.
        NOTE(review): presumably these two qubits are held in |0> and |1> respectively - confirm.
        """
        setter_index = 0 if value == 0 else 1
        super().append(Measure(), [self.cbit_setter_qreg[setter_index]], [cbit], copy=False)

    # Performs a NOT statement on a classical bit
    def cbit_not(self, cbit):
        """Flip a classical bit: set it to 0 when the bit tests true, otherwise set it to 1."""
        bit_is_set = expr.lift(cbit)
        with self.if_test(bit_is_set) as otherwise:
            self.set_cbit(cbit, 0)
        with otherwise:
            self.set_cbit(cbit, 1)

    # Performs AND and NOT statements on multiple classical bits, e.g. (~c[0] & ~c[1] & c[2])
    def cbit_and(self, cbits, values):
        """Build the AND of several classical bits, negating each bit whose expected value is 0.

        Args:
            cbits: Classical bits to combine; must contain at least one entry.
            values: Expected value for each bit, position by position.

        Returns:
            The combined classical expression, e.g. (~c[0] & ~c[1] & c[2]) for values (0, 0, 1).
        """
        if values[0] == 0:
            conjunction = expr.bit_not(cbits[0])
        else:
            conjunction = expr.lift(cbits[0])

        for position in range(1, len(cbits)):
            term = expr.bit_not(cbits[position]) if values[position] == 0 else cbits[position]
            conjunction = expr.bit_and(conjunction, term)

        return conjunction

    # XOR multiple classical bits
    def cbit_xor(self, cbits):
        """Build the XOR of several classical bits, folding from left to right.

        Args:
            cbits: Classical bits to combine; must contain at least one entry.

        Returns:
            The combined classical expression.
        """
        parity = expr.lift(cbits[0])
        for position in range(1, len(cbits)):
            parity = expr.bit_xor(parity, cbits[position])
        return parity

    ######################################
    ##### Visualization and analysis #####
    ######################################

    def draw(
        self,
        output=None,
        scale=None,
        filename=None,
        style=None,
        interactive=False,
        plot_barriers=True,
        reverse_bits=None,
        justify=None,
        vertical_compression="medium",
        idle_wires=None,
        with_layout=True,
        fold=None,
        # The type of ax is matplotlib.axes.Axes, but this is not a fixed dependency, so cannot be
        # safely forward-referenced.
        ax=None,
        initial_state=False,
        cregbundle=None,
        wire_order=None,
        expr_len=30,
        fold_qec=True,
        fold_logicalop=True,
    ):
        """
        LogicalCircuit drawer based on Qiskit circuit drawer

        Every argument is forwarded unchanged, by keyword, to `logical_circuit_drawer`,
        and that function's result is returned.
        """

        # Imported at call time rather than at module load
        from .Visualization.LogicalCircuitVisualization import logical_circuit_drawer

        drawer_options = dict(
            scale=scale,
            filename=filename,
            style=style,
            output=output,
            interactive=interactive,
            plot_barriers=plot_barriers,
            reverse_bits=reverse_bits,
            justify=justify,
            vertical_compression=vertical_compression,
            idle_wires=idle_wires,
            with_layout=with_layout,
            fold=fold,
            ax=ax,
            initial_state=initial_state,
            cregbundle=cregbundle,
            wire_order=wire_order,
            expr_len=expr_len,
            fold_qec=fold_qec,
            fold_logicalop=fold_logicalop,
        )

        return logical_circuit_drawer(self, **drawer_options)

class LogicalQubit(list):
    """
    A single LogicalQubit

    Placeholder container for the quantum and classical registers that make up one logical
    qubit. The constructor validates and stores the supplied registers, but the class is not
    finished: construction always ends by raising ``NotImplementedError``.
    """

    def __init__(self, regs=None, qregs=None, cregs=None):
        """
        Collect registers, then signal that the class is not usable yet.

        Args:
            regs: Optional iterable mixing ``QuantumRegister`` and ``ClassicalRegister`` objects.
            qregs: Optional iterable of ``QuantumRegister`` objects.
            cregs: Optional iterable of ``ClassicalRegister`` objects.

        Raises:
            TypeError: if an entry is not of a register type accepted by its argument.
            NotImplementedError: always, once every supplied register has been validated.
        """
        self._data = []
        self.qregs = []
        self.cregs = []

        if regs is not None:
            for reg in regs:
                self._add_register(reg, (QuantumRegister, ClassicalRegister), "regs")
        if qregs is not None:
            for qreg in qregs:
                self._add_register(qreg, (QuantumRegister,), "qregs")
        if cregs is not None:
            for creg in cregs:
                self._add_register(creg, (ClassicalRegister,), "cregs")

        raise NotImplementedError("LogicalQubit is not yet fully implemented")

    def _add_register(self, reg, allowed_types, source):
        """Validate ``reg`` against ``allowed_types`` and file it under qregs or cregs."""
        if not isinstance(reg, allowed_types):
            expected = " or ".join(allowed.__name__ for allowed in allowed_types)
            raise TypeError(
                f"LogicalQubit argument '{source}' expects {expected} instances, "
                f"got {type(reg).__name__}"
            )

        self._data.append(reg)
        # Quantum registers are checked first, matching the order used for mixed input
        if isinstance(reg, QuantumRegister):
            self.qregs.append(reg)
        else:
            self.cregs.append(reg)

class LogicalRegister(list):
    """
    A register containing LogicalQubits

    Placeholder: the constructor records the registers it was given and then raises,
    because the class is not implemented yet.
    """

    def __init__(self, qregs=None, cregs=None):
        """
        Store the supplied registers, then raise.

        Args:
            qregs: Quantum registers to associate with this logical register.
            cregs: Classical registers to associate with this logical register.

        Raises:
            NotImplementedError: always.
        """
        self.qregs, self.cregs = qregs, cregs
        raise NotImplementedError("LogicalRegister is not yet fully implemented")

class LogicalStatevector(Statevector):
    """
    A LogicalStatevector
    """

    def __init__(
        self,
        data: np.ndarray | QuantumCircuit | LogicalCircuit | Statevector,
        logical_circuit: LogicalCircuit = None,
        n_logical_qubits: int = None,
        label: Iterable[int] = None,
        stabilizer_tableau: Iterable[str] = None,
        dims: int = None
    ):
        """Initialize a LogicalStatevector object.

        Args:
            data: The data from which to construct the LogicalStatevector. This can be a
                `LogicalCircuit` object, a qiskit `Statevector`, or a complex data vector.
            logical_circuit: The LogicalCircuit describing the register layout of `data`.
                Only used (and required) when `data` is a qiskit `Statevector`; a deep copy
                is stored.
            n_logical_qubits: The number of logical qubits encoded in the statevector.
            label: The label of the quantum error correction code, i.e., [[n,k,d]] (given as a
                tuple), with n the number of physical qubits, k the number of logical qubits, and
                d the distance.
            stabilizer_tableau: The set of stabilizers for the QECC.
            dims: The subsystem dimension of the state. It is forwarded both to the
                full-system state and to the reduced state built from it.

        Raises:
            ValueError: if the data qubits are in a mixed state or the reduced state could
                not be built, or if required parameters are missing when constructing from a
                `Statevector` or a complex vector.
            TypeError: if `data` is a LogicalStatevector or is not a supported data type.
            NotImplementedError: if `data` is a plain qiskit QuantumCircuit.
        """
        if isinstance(data, LogicalCircuit):
            self.logical_circuit = copy.deepcopy(data)
            self.n_logical_qubits = self.logical_circuit.n_logical_qubits
            self.label = self.logical_circuit.label
            self.stabilizer_tableau = self.logical_circuit.stabilizer_tableau

            # This must happen before unboxing because QuantumCircuit does not
            # store logical_qreg information.
            non_data_qubits = self._non_data_qubit_indices(self.logical_circuit)

            # Circuit-to-instruction conversions can't handle QEC (due to ControlFlowOps and measurements),
            # nor can they handle other BoxOp's that may appear in the circuit (such as logical gates)
            pm_unbox = PassManager([ClearQEC(), UnBox()])
            while "box" in self.logical_circuit.count_ops():
                self.logical_circuit = pm_unbox.run(self.logical_circuit)

            # Circuit-to-instruction conversions can't handle other measurements either
            self.logical_circuit.remove_final_measurements()

            # Construct a Statevector for the full system, then reduce it to the data qubits
            sv_full = Statevector(data=self.logical_circuit, dims=dims)
            self._init_from_full_state(sv_full, non_data_qubits, dims, "LogicalCircuit")
        elif isinstance(data, QuantumCircuit):
            # @TODO - determine a good way to handle one (or both) of the two possible cases:
            #           1. QuantumCircuit was actually a LogicalCircuit that was casted at some point and thus should be treated like a LogicalCircuit
            #           2. QuantumCircuit is a regular physical circuit and should first be converted into a LogicalCircuit

            raise NotImplementedError("LogicalStatevector construction from QuantumCircuit is not yet supported; please provide a LogicalCircuit or an amplitude iterable")
        elif isinstance(data, LogicalStatevector):
            raise TypeError("Cannot construct a LogicalStatevector from another LogicalStatevector in this way; a deepcopy may be more appropriate for this purpose")
        elif isinstance(data, Statevector):
            # The circuit is compared against None explicitly: QuantumCircuit defines __len__,
            # so a circuit holding no instructions is falsy even though it was supplied.
            if n_logical_qubits and label and stabilizer_tableau and logical_circuit is not None:
                self.logical_circuit = copy.deepcopy(logical_circuit)
                self.n_logical_qubits = n_logical_qubits
                self.label = label
                self.stabilizer_tableau = stabilizer_tableau

                # Construct a Statevector for the full system, then reduce it to the data qubits
                sv_full = Statevector(data=data.data, dims=dims)
                non_data_qubits = self._non_data_qubit_indices(self.logical_circuit)
                self._init_from_full_state(sv_full, non_data_qubits, dims, "Statevector")
            else:
                raise ValueError("LogicalStatevector construction from a Statevector requires n_logical_qubits, label, stabilizer_tableau, and logical_circuit to all be specified")
        elif hasattr(data, "__iter__"):
            if n_logical_qubits and label and stabilizer_tableau:
                self.logical_circuit = None
                self.n_logical_qubits = n_logical_qubits
                self.label = label
                self.stabilizer_tableau = stabilizer_tableau

                super().__init__(data=data, dims=dims)
            else:
                raise ValueError("LogicalStatevector construction from an amplitude iterable requires n_logical_qubits, label, and stabilizer_tableau to all be specified")
        else:
            raise TypeError(f"Object of type {type(data)} is not a valid data input for LogicalStatevector")

        # Defer computation until necessary
        self._logical_decomposition = None

    @staticmethod
    def _non_data_qubit_indices(logical_circuit):
        """Return the flat indices of all qubits whose register is not in ``logical_qregs``.

        Indices are assigned by walking ``logical_circuit.qregs`` in order and counting
        register sizes.
        """
        non_data_qubits = []
        offset = 0
        for qreg in logical_circuit.qregs:
            if qreg not in logical_circuit.logical_qregs:
                non_data_qubits.extend(range(offset, offset + qreg.size))
            offset += qreg.size
        return non_data_qubits

    def _init_from_full_state(self, sv_full, non_data_qubits, dims, source):
        """Trace out ``non_data_qubits`` from ``sv_full`` and initialise from the pure remainder.

        ``source`` names the input kind and is only used in the error messages.
        """
        # Partial trace over the non-data qubits yields a DensityMatrix
        dm_partial = partial_trace(sv_full, non_data_qubits)

        try:
            sv_partial = dm_partial.to_statevector()
            super().__init__(data=sv_partial.data, dims=dims)
        except QiskitError as e:
            raise ValueError(f"Unable to construct LogicalStatevector from {source} because data qubits are in a mixed state; a LogicalDensityMatrix may be the best alternative") from e
        except Exception as e:
            raise ValueError(f"Unable to construct LogicalStatevector from {source}") from e

    @classmethod
    def from_counts(
        cls,
        counts: dict[str, int],
        n_logical_qubits: int,
        label: Iterable[int],
        stabilizer_tableau: Iterable[str],
        basis: str = "physical"
    ) -> LogicalStatevector:
        """Construct a LogicalStatevector from measurement counts.

        Each outcome contributes an amplitude equal to the square root of its probability.
        Raw keys that reduce to the same outcome string have their counts summed first, so
        the amplitude is the square root of the combined probability.

        Args:
            counts (dict): The set of counts measured from a circuit execution.
            n_logical_qubits (int): The number of logical qubits.
            label (tuple): The quantum error correction code [[n,k,d]] (given as a tuple).
            stabilizer_tableau (Iterable[str]): The set of stabilizers for the QECC.
            basis (str): The basis in which each respective count's vector is given, physical or logical.

        Returns:
            The normalized LogicalStatevector constructed from the counts.

        Raises:
            ValueError: if the counts format could not be parsed, if `basis` is invalid, or
                if `counts` is empty or its values do not sum to a positive number.
        """
        if basis not in ("physical", "logical"):
            raise ValueError(f"'{basis}' is not a valid basis for LogicalStatevector array representation")

        n_data_qubits = label[0]

        # Accumulate shots per parsed outcome; dict insertion order keeps the result deterministic
        shots_by_outcome = {}
        for raw_key, shots in counts.items():
            outcome_raw = raw_key.replace(" ", "")
            if basis == "physical":
                outcome = cls._data_bits_from_outcome(outcome_raw, n_data_qubits)
            else:
                # @TODO - make sure this is correct
                # No parsing required
                outcome = outcome_raw
            shots_by_outcome[outcome] = shots_by_outcome.get(outcome, 0) + shots

        if not shots_by_outcome:
            raise ValueError("Cannot construct a LogicalStatevector from empty counts")

        total_shots = sum(shots_by_outcome.values())
        if total_shots <= 0:
            raise ValueError("Cannot construct a LogicalStatevector from counts that do not sum to a positive number")

        lsv_unnormalized = None
        for outcome, shots in shots_by_outcome.items():
            amplitude = np.sqrt(shots / total_shots)
            lsv_term = amplitude * cls.from_basis_str(basis_str=outcome, n_logical_qubits=n_logical_qubits, label=label, stabilizer_tableau=stabilizer_tableau)

            if lsv_unnormalized is None:
                lsv_unnormalized = lsv_term
            else:
                lsv_unnormalized += lsv_term

        norm = np.linalg.norm(lsv_unnormalized.data)
        if not np.isclose(norm, 0.0):
            lsv_normalized = lsv_unnormalized / norm
        else:
            lsv_normalized = lsv_unnormalized

        return lsv_normalized

    @staticmethod
    def _data_bits_from_outcome(outcome_raw, n_data_qubits):
        """Extract the data-qubit bitstring from one whitespace-free counts key.

        Accepts a plain bitstring, a ``0b``-prefixed bitstring or a ``0x``-prefixed hex string.
        """
        if all(char in ("0", "1") for char in outcome_raw):
            binary = outcome_raw
        elif outcome_raw.startswith("0b"):
            binary = outcome_raw[2:]
        elif outcome_raw.startswith("0x"):
            binary = bin(int(outcome_raw, 16))[2:]
        else:
            raise ValueError("Could not resolve count format")

        # Short strings (e.g. 0b0 or 0x0) are left-padded with zeros; a string that already
        # has exactly n_data_qubits characters is kept whole.
        if len(binary) <= n_data_qubits:
            return binary.rjust(n_data_qubits, "0")

        # @TODO - find a more reliable method that does not rely on the current indexing
        # NOTE(review): assumes exactly one non-data bit precedes the data bits - confirm
        return binary[1:1 + n_data_qubits]

    @classmethod
    def from_basis_str(cls, basis_str, n_logical_qubits, label, stabilizer_tableau, basis="physical"):
        """Construct a LogicalStatevector, an element of the computational basis, from
            a basis string.

        The vector dimension follows from the number of digits in `basis_str`: 2 per binary
        digit and 16 per hex digit (prefix excluded).

        Args:
            basis_str (str): A binary bitstring (bare or prefixed with ``0b``) or its hex
                equivalent (prefixed with ``0x``) identifying the basis element.
            n_logical_qubits (int): Number of logical qubits.
            label (tuple): The label of the quantum error correction code [[n,k,d]] (as a tuple).
            stabilizer_tableau (Iterable[str]): The set of stabilizers for the QECC.
            basis (str): The basis in which the vector is given, physical or logical.
                Currently not used by this method.

        Returns:
            :py:class:`~LogicalQ.Logical.LogicalStatevector`: The LogicalStatevector of
                the given basis state.

        Raises:
            ValueError: if the `basis_str` could not be parsed due to improper format.
        """
        if all(char in ("0", "1") for char in basis_str):
            dimension = 2 ** len(basis_str)
            basis_idx = int(basis_str, 2)
        elif basis_str.startswith("0b"):
            # The two prefix characters do not count as digits
            dimension = 2 ** (len(basis_str) - 2)
            basis_idx = int(basis_str[2:], 2)
        elif basis_str.startswith("0x"):
            dimension = 16 ** (len(basis_str) - 2)
            basis_idx = int(basis_str[2:], 16)
        else:
            raise ValueError("Could not resolve basis_str format")

        basis_vector = np.zeros((dimension,))
        basis_vector[basis_idx] = 1.0

        return cls(data=basis_vector, n_logical_qubits=n_logical_qubits, label=label, stabilizer_tableau=stabilizer_tableau)

    @property
    def logical_decomposition(self, atol=1E-13):
        """Give a decomposition of a LogicalStatevector into the logical basis.

        The result is computed on first access and cached on the instance. For every logical
        basis state a fresh LogicalCircuit is encoded, converted to a LogicalStatevector, and
        its overlap with this state is taken.

        Args:
            atol (float): Tolerance within which to set probability amplitude to zero.
                Because this is a property, the default is always the value in effect.

        Returns:
            :py:type:`np.ndarray`: The set of coefficients :math:`\\alpha_i, \\delta`, where
                :math:`|\\psi\\rangle = \\sum_{x=0}^{2^n - 1}\\alpha_x|x\\rangle + \\delta|\\psi^\\perp\\rangle`,
                where :math:`|\\psi^\\perp\\rangle` is the component of the state vector not in the
                codespace. Note that coefficients are returned in ascending order of value, e.g.,
                000, 001, 010, 011, 100, 101, etc.
        """
        if self._logical_decomposition is None:
            n_qubits = self.n_logical_qubits
            # Built-in integer power: np.pow is only available from NumPy 2.0 onwards
            n_basis_states = 2 ** n_qubits

            coeffs = []
            for basis_index in range(n_basis_states):
                # Bits of basis_index, most significant first, i.e. ascending order of value
                initial_state = [(basis_index >> shift) & 1 for shift in reversed(range(n_qubits))]

                lqc = LogicalCircuit(self.n_logical_qubits, self.label, self.stabilizer_tableau)
                lqc.encode(range(n_qubits), initial_states=initial_state)
                basis_lsv = LogicalStatevector(lqc)

                coeffs.append(np.vdot(basis_lsv.data, self.data))

            # Whatever weight is not captured by the codespace goes into the orthogonal component
            captured_weight = np.sum(np.abs(coeffs) ** 2)
            delta = np.sqrt(np.maximum(0.0, 1 - captured_weight))

            decomposition = np.array([*coeffs, delta], dtype=complex)
            # Zero out real and imaginary parts that are below the tolerance
            real_part = np.where(np.abs(decomposition.real) < atol, 0.0, decomposition.real)
            imag_part = np.where(np.abs(decomposition.imag) < atol, 0.0, decomposition.imag)
            decomposition = real_part + 1.j*imag_part
            decomposition /= np.linalg.norm(decomposition)

            self._logical_decomposition = decomposition

        return self._logical_decomposition

    # @TODO - find a way to let basis="logical" by default but without causing a recursive loop during logical_decomposition computation
    def __array__(self, basis="physical", dtype=None, copy=_numpy_compat.COPY_ONLY_IF_NEEDED):
        """Return an array representation of the object.

        NumPy's array-coercion machinery calls ``__array__`` with the requested dtype as the
        first positional argument. Since that slot is `basis` here, a non-string `basis` is
        interpreted as the dtype and the physical basis is used.

        Args:
            basis (str): The basis, physical or logical, in which to represent the vector.
            dtype (data-type): The desired data-type for the array.
            copy (bool): If `True`, then the array data is copied

        Returns:
            :py:type:`np.ndarray`: The :py:meth:`~LogicalQ.Logical.LogicalStatevector.logical_decomposition`
                if basis is logical and the statevector data otherwise.

        Raises:
            ValueError: if the basis is invalid.
        """
        if not isinstance(basis, str):
            # Positional dtype coming from NumPy; an explicit dtype keyword still wins
            if dtype is None:
                dtype = basis
            basis = "physical"

        dtype = self.data.dtype if dtype is None else dtype

        if basis == "logical":
            return np.array(self.logical_decomposition, dtype=dtype, copy=copy)
        elif basis == "physical":
            return np.array(self.data, dtype=dtype, copy=copy)
        else:
            raise ValueError(f"'{basis}' is not a valid basis for LogicalStatevector array representation")

    def __repr__(self, basis="logical"):
        """Return a string representation of the statevector.

        Args:
            basis (str): The basis, logical or physical, whose coefficients are printed.

        Returns:
            :py:type:`str`: ``Statevector(<array>, dims=<dims>)`` with the array taken from the
                logical decomposition or from the raw statevector data.

        Raises:
            ValueError: if the basis is invalid.
        """
        if basis not in ("logical", "physical"):
            raise ValueError(f"'{basis}' is not a valid basis for LogicalStatevector string representation")

        data = self.logical_decomposition if basis == "logical" else self.data

        header = "Statevector("
        # Continuation lines are aligned under the opening parenthesis
        indent = " " * len(header)
        array_text = np.array2string(data, separator=', ', prefix=header)

        return f"{header}{array_text},\n{indent}dims={self._op_shape.dims_l()})"

    def draw(self, output=None):
        """Return a visual representation of the statevector in the logical basis.

        One term is produced per logical basis state, labelled by its bitstring padded to
        ``n_logical_qubits`` digits, followed by a final term for the component outside the
        codespace (the last entry of the logical decomposition).

        Args:
            output (str): The method in which to draw. Valid choices are `text`, `latex`, and
                `latex_source`. Defaults to `text`.

        Returns:
            :py:type:`str` or :py:class:`IPython.display.Latex`: String or LaTeX representation of the statevector.

        Raises:
            ValueError: if the draw method is invalid.
        """
        if output is None:
            output = "text"

        if output not in ("text", "latex", "latex_source"):
            raise ValueError(f"'{output}' is not a valid LogicalStatevector draw method, please choose from 'text', 'latex', or 'latex_source'")

        # Everything but the last entry is a codespace coefficient
        *coeffs, delta = self.logical_decomposition
        width = self.n_logical_qubits
        kets = [format(index, f"0{width}b") for index in range(len(coeffs))]

        # @TODO - display scientific notation correctly in string formatting
        if output == "text":
            terms = [f"{coeff} |{ket}>_L" for coeff, ket in zip(coeffs, kets)]
            terms.append(f"{delta} |psi_L^perp>")
            return " + ".join(terms)

        terms = [f"{coeff} \\ket{{{ket}}}_L" for coeff, ket in zip(coeffs, kets)]
        terms.append(f"{delta} \\ket{{\\psi_L^\\perp}}")

        latex = "".join([
            "$$",
            "\\begin{align}\n",
            " + ".join(terms),
            "\\end{align}\n",
            "$$",
        ])

        if output == "latex":
            from IPython.display import Latex

            return Latex(latex)

        return latex

class LogicalDensityMatrix(DensityMatrix):
    """
    A LogicalDensityMatrix
    """

    def __init__(self, data, n_logical_qubits=None, label=None, stabilizer_tableau=None, dims=None):
        """Initialize a LogicalDensityMatrix object.

        Args:
            data: A `LogicalCircuit`, or a square matrix (array or nested sequence) whose
                dimension is a power of 2.
            n_logical_qubits: The number of logical qubits; required for matrix input.
            label: The QECC label [[n,k,d]] (given as a tuple); required for matrix input.
            stabilizer_tableau: The set of stabilizers for the QECC; required for matrix input.
            dims: The subsystem dimension of the state, forwarded to the qiskit constructors.

        Raises:
            ValueError: if matrix input is not square with power-of-2 dimension, or if the
                code parameters are missing for matrix input.
            TypeError: if `data` is not a supported data type.
            NotImplementedError: if `data` is a plain qiskit QuantumCircuit or if more than
                one logical qubit is requested.
        """
        if isinstance(data, LogicalCircuit):
            self.logical_circuit = copy.deepcopy(data)
            self.n_logical_qubits = self.logical_circuit.n_logical_qubits
            self.label = self.logical_circuit.label
            self.stabilizer_tableau = self.logical_circuit.stabilizer_tableau

            # Reject unsupported input before running the (expensive) simulation
            self._require_single_logical_qubit(self.n_logical_qubits)

            # Circuit-to-instruction conversions can't handle QEC (due to ControlFlowOps and measurements),
            # nor can they handle other BoxOp's that may appear in the circuit (such as logical gates)
            pm_unbox = PassManager([ClearQEC(), UnBox()])
            while "box" in self.logical_circuit.count_ops():
                self.logical_circuit = pm_unbox.run(self.logical_circuit)

            # Circuit-to-instruction conversions can't handle other measurements either
            self.logical_circuit.remove_final_measurements()

            # First, construct a DensityMatrix object for the full system
            ldm_full = DensityMatrix(data=self.logical_circuit, dims=dims)

            # Then, partial trace over the non-data qubits to obtain a new DensityMatrix
            # NOTE(review): assumes the data qubits occupy the first label[0] indices - confirm
            non_data_qubits = list(range(self.label[0], self.logical_circuit.num_qubits))
            ldm_partial = partial_trace(ldm_full, non_data_qubits)

            super().__init__(data=ldm_partial.data, dims=dims)
        elif isinstance(data, QuantumCircuit):
            # @TODO - determine a good way to handle one (or both) of the two possible cases:
            #           1. QuantumCircuit was actually a LogicalCircuit that was casted at some point and thus should be treated like a LogicalCircuit
            #           2. QuantumCircuit is a regular physical circuit and should first be converted into a LogicalCircuit

            raise NotImplementedError("LogicalDensityMatrix construction from QuantumCircuit is not yet supported; please provide a LogicalCircuit or an amplitude iterable")
        elif hasattr(data, "__iter__"):
            # Inspect the shape through an array view so nested lists are validated too
            matrix = np.asarray(data)
            is_square = matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]
            dimension = matrix.shape[0] if matrix.ndim >= 1 else 0
            is_power_of_two = dimension > 0 and (dimension & (dimension - 1)) == 0
            if not (is_square and is_power_of_two):
                raise ValueError("LogicalDensityMatrix data must be a square matrix whose dimension is a power of 2.")
            if n_logical_qubits and label and stabilizer_tableau:
                self.logical_circuit = None
                self.n_logical_qubits = n_logical_qubits
                self.label = label
                self.stabilizer_tableau = stabilizer_tableau

                self._require_single_logical_qubit(self.n_logical_qubits)

                super().__init__(data=data, dims=dims)
            else:
                raise ValueError("LogicalDensityMatrix construction from an amplitude iterable requires n_logical_qubits, label, and stabilizer_tableau to all be specified")
        else:
            raise TypeError(f"Object of type {type(data)} is not a valid data input for LogicalDensityMatrix")

        # @TODO - how will this be obtained for a logical density matrix?
        # Defer computation until necessary
        self._logical_decomposition = None

        print("WARNING - LogicalDensityMatrix has not been fully implemented yet!")

    @staticmethod
    def _require_single_logical_qubit(n_logical_qubits):
        """Raise NotImplementedError when more than one logical qubit is requested."""
        # @TODO
        if n_logical_qubits > 1:
            raise NotImplementedError("LogicalDensityMatrix does not yet support circuits with multiple logical qubits")

    # @TODO - implement
    @property
    def logical_decomposition(self, atol=1E-13):
        """Logical-basis decomposition of the density matrix; not implemented yet.

        Args:
            atol (float): Currently unused. Because this is a property, callers cannot
                override the default.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def __repr__(self, basis="logical"):
        """Return a string representation of the density matrix.

        The logical basis is not implemented yet: requesting it prints a warning and the
        physical data is shown instead.

        Args:
            basis (str): The requested basis, logical or physical.

        Returns:
            str: ``DensityMatrix(<array>, dims=<dims>)`` built from the physical data.

        Raises:
            ValueError: if the basis is invalid.
        """
        if basis == "logical":
            print("WARNING - Logical basis representation is not yet implemented, using physical representation instead")
        elif basis != "physical":
            raise ValueError(f"'{basis}' is not a valid basis for LogicalDensityMatrix string representation")

        # Both accepted bases currently render the physical data
        data = self.data

        header = "DensityMatrix("
        indent = " " * len(header)
        array_text = np.array2string(data, separator=', ', prefix=header)

        return f"{header}{array_text},\n{indent}dims={self._op_shape.dims_l()})"

def logical_state_fidelity(state1, state2):
    """Compute the state fidelity between two (logical or plain) quantum states.

    Two LogicalStatevectors, or two LogicalDensityMatrix objects, are compared directly.
    In any other combination a LogicalStatevector is first replaced by a plain Statevector
    holding its codespace coefficients (the logical decomposition without its final
    out-of-codespace entry), while plain Statevector / DensityMatrix inputs are used as-is.

    Args:
        state1: LogicalStatevector, LogicalDensityMatrix, Statevector or DensityMatrix.
        state2: LogicalStatevector, LogicalDensityMatrix, Statevector or DensityMatrix.

    Returns:
        The value returned by qiskit's ``state_fidelity`` with ``validate=False``.

    Raises:
        NotImplementedError: if a LogicalDensityMatrix is paired with a different state type.
        TypeError: if either input is not one of the supported state types.
    """
    both_statevectors = isinstance(state1, LogicalStatevector) and isinstance(state2, LogicalStatevector)
    both_density_matrices = isinstance(state1, LogicalDensityMatrix) and isinstance(state2, LogicalDensityMatrix)

    if both_statevectors or both_density_matrices:
        # Special cases
        states = [state1, state2]
    else:
        # Other cases
        states = []
        for s, _state in enumerate((state1, state2)):
            # The logical types are tested first because they subclass the plain qiskit types
            if isinstance(_state, LogicalStatevector):
                # @TODO - determine whether this is the best thing to do in this case
                # Drop only the trailing out-of-codespace coefficient so that any number of
                # logical qubits is handled; the result is intentionally left unnormalized.
                state = Statevector(_state.logical_decomposition[:-1])
            elif isinstance(_state, LogicalDensityMatrix):
                # @TODO - determine the best thing to do in this case
                raise NotImplementedError("state_fidelity computation for LogicalDensityMatrix instances is not yet implemented")
            elif isinstance(_state, (Statevector, DensityMatrix)):
                state = _state
            else:
                raise TypeError(f"Invalid type for state at index {s}: {type(_state)}")

            states.append(state)

    return state_fidelity(states[0], states[1], validate=False)

