#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
In this script, we implement the cross-platform estimation method described in
``Cross-Platform Verification of Intermediate Scale Quantum Devices`` [EVB+20]_,
which aims at determining the overlap between the quantum states generated by different quantum platforms.

References:

.. [EVB+20] Elben, Andreas, et al.
            "Cross-platform verification of intermediate scale quantum devices."
            Physical Review Letters 124.1 (2020): 010504.
"""
from QCompute import *
from QCompute.QPlatform.QOperation import CircuitLine
from QCompute.QPlatform.QOperation import RotationGate
import abc
from typing import List, Dict
from copy import deepcopy
import numpy as np
from scipy.spatial.distance import hamming
from scipy.stats import unitary_group
import collections
from datetime import datetime
import json
import matplotlib.pyplot as plt
from tqdm import tqdm

from qcompute_qep.exceptions.QEPError import ArgumentError
import qcompute_qep.estimation as estimation
from qcompute_qep.utils.types import QComputer, QProgram, number_of_qubits
from qcompute_qep.utils.utils import decompose_yzy
from qcompute_qep.utils.circuit import print_circuit, execute


class QuantumSnapshot(abc.ABC):
    """The Quantum Snapshot class.

    Used to record all information of a quantum circuit and quantum computer: the target
    qubits, the recorded measurement outcomes, the random unitaries they were measured
    under, and a printable description of the state-preparation circuit.
    """
    def __init__(self, qc_name: str, qc: 'QComputer' = None, **kwargs):
        """The init function of the Quantum Snapshot class.

        Optional keywords list are:

            + `qubits`: List[int], default to None, the index of target qubit(s)
            + `counts`: List[str], default to an empty list, the recorded measurement outcomes
              (one bit string per entry)
            + `unitaries`: Dict[str, List[List[float]]], default to an empty dict, the sampled
              unitaries, keyed by sample name, each a list of per-qubit parameter triples
            + `qp`: str, default to an empty string, the printed quantum circuit

        :param qc_name: str, the name of input quantum device
        :param qc: QComputer, quantum computer

        .. code-block:: python
            :linenos:

            ideal_baidu = QuantumSnapshot(qc_name='Baidu ideal', qc=QCompute.BackendName.LocalBaiduSim2, qubits=[0])
            qian_baidu = QuantumSnapshot(qc_name='Baidu Qian', qc=QCompute.BackendName.CloudBaiduQPUQian, qubits=[0])

        """
        self._qc_name = qc_name
        self._qc = qc
        self._qubits: List[int] = kwargs.get('qubits', None)
        self._counts: List[str] = kwargs.get('counts', [])
        self._unitaries: Dict[str, List[List[float]]] = kwargs.get('unitaries', dict())
        # Honour a circuit description supplied by the caller (e.g. one restored from a saved file)
        self._qp: str = kwargs.get('qp', "")

    def save_data(self, file_name: str = None):
        """Save the Quantum Snapshot information to a JSON file.

        The default file name is `qc_name_<month>_<day>_<hour>.devinf`.
        Note: if there already exists a file, then we will overwrite.

        :param file_name: str, optional, the target file path
        """
        if file_name is None:
            file_name = self._qc_name + datetime.now().strftime("_%m_%d_%H") + ".devinf"
        file_data = {'qubits': self._qubits,
                     'counts': self._counts,
                     'unitaries': self._unitaries,
                     'qp': self._qp}
        with open(file_name, 'w') as f:
            json.dump(file_data, f, separators=(',', ':'), indent=4)

    @property
    def qc(self):
        """The quantum computer this snapshot targets (may be None for loaded snapshots)."""
        return self._qc

    @property
    def qc_name(self):
        """The display name of the quantum device."""
        return self._qc_name

    @property
    def qubits(self):
        """The target qubit indices, or None."""
        return self._qubits

    @property
    def counts(self):
        """The recorded measurement outcomes (list of bit strings)."""
        return self._counts

    @property
    def unitaries(self):
        """The sampled unitaries the counts were measured under."""
        return self._unitaries

    @property
    def qp(self):
        """The printed quantum circuit."""
        return self._qp

    @qp.setter
    def qp(self, new_qp: str):
        self._qp = new_qp

    @counts.setter
    def counts(self, new_counts):
        self._counts = new_counts

    @unitaries.setter
    def unitaries(self, new_unitaries):
        self._unitaries = new_unitaries


class CPEState(estimation.Estimation):
    r"""The "Cross-Platform Estimation of Quantum State" class.

    `Cross-Platform Verification of Intermediate Scale Quantum Devices` is the procedure that
    determines the overlap between the quantum states generated by different quantum platforms.
    Thus, it can be used to estimate the fidelity for mixed states as

    .. math::

        F_{\rm max}(\rho_1, \rho_2) = \frac{{\rm Tr}[\rho_1 \rho_2]}{{\rm max}({\rm Tr}[\rho_1^2], {\rm Tr}[\rho_2^2])}.
    """
    def __init__(self, qc_list: List[QuantumSnapshot] = None, qp: QProgram = None, **kwargs):
        r"""The init function of the Cross Platform Estimation class.

        Optional keywords list are:

            + `samples`: default to :math:`100`, the number of sampled unitaries
            + `shots`: default to :math:`512`, the number of shots each measurement should carry out

        :param qc_list: List[QuantumSnapshot], the target quantum device(s)
        :param qp: QProgram, quantum state preparation circuit

        """
        super().__init__(qp, **kwargs)
        self._qc_list = qc_list
        self._unitaries: Dict[str, List[List[float]]] = {}
        self._shots = kwargs.get('shots', 512)
        self._samples = kwargs.get('samples', 100)
        self._result: np.ndarray = None

    def estimate(self, qc_list: List[QuantumSnapshot] = None, qp: QProgram = None, **kwargs):
        r"""Execute the quantum cross-platform estimation procedure on the quantum computer(s).

        Optional keywords list are:

            + `samples`: default to :math:`100`, the number of sampled unitaries (at least 1)
            + `shots`: default to :math:`512`, the number of shots each measurement should carry out
              (at least 2 when comparing two or more devices, since purities are estimated from
              pairs of distinct shots)
            + `show`: default to False, when `show=True` we will visualize the results
            + `filename`: default to None, if set, we will save the result as `filename`.

        :param qc_list: List[QuantumSnapshot], the quantum device(s); falls back to the list given at construction
        :param qp: QProgram, quantum state preparation circuit; falls back to the one given at construction
        :return result: The matrix of which the elements in :math:`i_{th}` row and :math:`j_{th}` column
                         (:math:`i < j`) is the estimated fidelity :math:`F_{\rm max}(\rho_i, \rho_j)`, where
                         :math:`i`, :math:`j` are the index of quantum devices in the qc_list. The diagonal is
                         :math:`1` and the lower triangle is left at :math:`0`. An entry is NaN when both
                         estimated purities are non-positive (possible only with very few shots).
        :raises ArgumentError: if `qp` or `qc_list` is not set, or the inputs are inconsistent

        We consider the fidelity for mixed states as

        .. math::

            F_{\rm max}(\rho_1, \rho_2) = \frac{{\rm Tr}[\rho_1 \rho_2]}{{\rm max}({\rm Tr}[\rho_1^2], {\rm Tr}[\rho_2^2])},

        where :math:`{\rm Tr}[\rho_1 \rho_2], {\rm Tr}[\rho_1^2], {\rm Tr}[\rho_2^2]` are determined by

        .. math::
            {\rm Tr}[\rho_i \rho_j] = 2^{N} \sum_{s, s'} (-2)^{-D[s, s']} \mathbb{E}[P^{(i)}_U(s)P^{(j)}_U(s')],

        where :math:`N` is the number of qubits, :math:`s` and :math:`s'` are the standard basis used for measurements, :math:`D[s, s']`
        is the hamming distance between two strings :math:`s` and :math:`s'`, :math:`P_U^{(i)}(s) = {\rm Tr}[U \rho_i U^{\dagger}\vert s \rangle \langle s \vert]`,
        and :math:`\mathbb{E}[\cdot]` is the average over random unitaries.

        Usage:

        .. code-block:: python
            :linenos:

            result = est.estimate(qc_list, qp=qp, show=True, filename='test.png')
            result = est.estimate(qc_list, qp=qp, samples=100, shots=100)

        **Examples**

            >>> import QCompute
            >>> from qcompute_qep.estimation.cpe_state import QuantumSnapshot, CPEState
            >>> ideal_baidu1 = QuantumSnapshot(qc_name='Baidu ideal1',
            >>>                                 qc=QCompute.BackendName.LocalBaiduSim2,
            >>>                                 qubits=[0])
            >>> ideal_baidu2 = QuantumSnapshot(qc_name='Baidu ideal2',
            >>>                                 qc=QCompute.BackendName.LocalBaiduSim2,
            >>>                                 qubits=[1])
            >>> qp = QCompute.QEnv()
            >>> qp.Q.createList(1)
            >>> est = CPEState()
            >>> result = est.estimate([ideal_baidu1, ideal_baidu2], qp,
            >>>                         samples=100, shots=50, show=True, filename='test.png')

        """
        self._qc_list = qc_list if qc_list is not None else self._qc_list
        self._qp = qp if qp is not None else self._qp
        self._shots = kwargs.get('shots', self._shots)
        self._samples = kwargs.get('samples', self._samples)

        if self._qp is None:
            raise ArgumentError("in CPEState.estimate(): the quantum program is not set!")
        if self._qc_list is None:
            raise ArgumentError("in CPEState.estimate(): the quantum computer list is not set!")
        if self._samples < 1:
            raise ArgumentError("in CPEState.estimate(): at least 1 sample of random unitaries is required!")
        if len(self._qc_list) >= 2 and self._shots < 2:
            # Purities need pairs of *distinct* shots measured under the same unitary
            raise ArgumentError("in CPEState.estimate(): at least 2 shots are required to estimate purities!")
        if self._pre_verify() is False:
            raise ArgumentError("in CPEState.estimate(): the input qc_list is illegal!")

        # Built only after validation so that a missing qc_list raises ArgumentError, not TypeError
        self._result = np.identity(len(self._qc_list), dtype=float)

        with tqdm(total=100, desc='CPEState Step 1/3 : Sampling unitaries...', ncols=80) as pbar:
            # Step 1. Generate or read unitaries and construct a list of circuit
            if len(self._unitaries) != self._samples:
                self._unitaries.clear()
                self._generate_random_unitaries()
            qp_list = self._construct_circuits()
            pbar.update(100 / 3)

            # Step 2. Run circuits on all quantum computer and collect experiment data
            for quantum_snapshot in self._qc_list:
                pbar.desc = "CPEState Step 2/3 : Running circuits on {}..."\
                    .format(quantum_snapshot.qc_name)
                pbar.update(100 / 3 / len(self._qc_list))
                # Save the circuit information to quantum_snapshot
                quantum_snapshot.qp = print_circuit(self._qp.circuit, show=False,
                                                    num_qubits=number_of_qubits(self._qp))
                quantum_snapshot.unitaries = self._unitaries

                # The counts are empty (loaded/previous counts of the wrong size are rejected by _pre_verify)
                if len(quantum_snapshot.counts) < self._samples * self._shots:
                    qc = quantum_snapshot.qc
                    qubits = quantum_snapshot.qubits
                    # Outcomes are appended circuit by circuit, `shots` entries each, so the
                    # k-th block of `shots` outcomes belongs to the k-th sampled unitary.
                    for circuit in qp_list:
                        mapped_circuit = _mapping_qubits(circuit, qubits)
                        quantum_snapshot.counts += [list(execute(qp=mapped_circuit, qc=qc, shots=1).keys())[0]
                                                    for _ in range(self._shots)]

            pbar.desc = "CPEState Step 3/3 : Processing experimental data..."
            # Only input one quantum snapshot, then return
            if len(self._qc_list) < 2:
                pbar.update(100 - pbar.n)
                pbar.desc = "Successfully finished CPEState!"
                print("Please input more QuantumSnapshot")
                return self._result

            # Step 3. Process experiment data and compare each quantum devices.
            # Each purity is needed by several pairs, so compute it once per device.
            purities = [self._compute_fidelity(snapshot.counts, snapshot.counts) for snapshot in self._qc_list]
            num_devices = len(self._qc_list)
            for i in range(num_devices):
                for j in range(i + 1, num_devices):
                    overlap = self._compute_fidelity(self._qc_list[i].counts, self._qc_list[j].counts)
                    denominator = max(purities[i], purities[j])
                    self._result[i, j] = overlap / denominator if denominator > 0 else np.nan
            pbar.update(100 - pbar.n)
            pbar.desc = "Successfully finished CPEState!"

        # Step 4. Visualize experiment result
        file_name = kwargs.get('filename', None)
        if kwargs.get('show', False) is True or file_name is not None:
            self.visualize(file_name)

        return self._result

    def visualize(self, file_name: str = None):
        """Visualize the results of Cross-Platform Estimation.

        Draws the result matrix as a heat map annotated on and above the diagonal.

        :param file_name: str, optional; if None the figure is shown interactively,
                          otherwise it is saved to this path and then closed
        """
        fig, ax = plt.subplots()
        ax.imshow(self._result, cmap='Greens', vmin=0, vmax=3)

        device_name = [snap.qc_name for snap in self._qc_list]

        plt.xticks(np.arange(len(device_name)), device_name)

        if len(self._qc_list) > 4:
            ax.set_xticklabels(device_name, rotation=45)

        ax.xaxis.set_ticks_position('top')
        plt.yticks(np.arange(len(device_name)), device_name)
        ax.yaxis.set_ticks_position('right')
        for i in range(len(device_name)):
            for j in range(i, len(device_name)):
                ax.text(j, i, '{:.4f}'.format(self._result[i, j]),
                        ha='center', va='center', color='black')

        fig.tight_layout()
        if file_name is None:
            plt.show()
        else:
            plt.savefig(file_name, bbox_inches='tight', dpi=500)
            # Release the figure so repeated saves do not accumulate open figures
            plt.close(fig)

    def _pre_verify(self) -> bool:
        """Verify if the input qc_list and qp are legal.

        Side effect: adopts the first snapshot's unitaries (if it has exactly `samples` of them)
        as the shared set of unitaries.

        :return: True when all snapshots are consistent
        :raises ArgumentError: on inconsistent unitaries, counts, or qubits
        """
        n = number_of_qubits(self._qp)
        for quantum_snapshot in self._qc_list:
            # Verify the input unitaries is legal
            if len(quantum_snapshot.unitaries) == self._samples:
                # Have not set unitaries
                if len(self._unitaries) == 0:
                    self._unitaries = quantum_snapshot.unitaries
                elif self._unitaries != quantum_snapshot.unitaries:
                    raise ArgumentError("There exist differences between input QuantumSnapshot's unitaries!")

            # Verify the input counts is legal
            num_counts = len(quantum_snapshot.counts)
            if num_counts > 0 and num_counts != self._samples * self._shots:
                raise ArgumentError("{}'s counts number is illegal!".format(quantum_snapshot.qc_name))
            # Recorded counts are only meaningful together with the unitaries they were measured under
            if num_counts > 0 and len(quantum_snapshot.unitaries) != self._samples:
                raise ArgumentError("{}'s counts are given without matching unitaries!"
                                    .format(quantum_snapshot.qc_name))

            # Verify the input qubits is legal (None or empty means "use the circuit's own qubits")
            qubits = quantum_snapshot.qubits
            if qubits and len(qubits) != n:
                raise ArgumentError("{}'s qubits is illegal!".format(quantum_snapshot.qc_name))

        return True

    def _generate_random_unitaries(self):
        """Generate `samples` sets of random single-qubit unitaries.

        Each set holds one `[theta, phi, lam]` triple per qubit, obtained from a Haar-random
        2x2 unitary via `decompose_yzy` (the global phase is discarded), and is stored in
        `self._unitaries` under the key `sample_<index>`.
        """
        n = number_of_qubits(self._qp)
        for i in range(self._samples):
            unitaries_list = []
            for _ in range(n):
                u = unitary_group.rvs(2)
                _, theta, phi, lam = decompose_yzy(u)
                unitaries_list.append([theta, phi, lam])
            self._unitaries.update({'sample_{}'.format(i): unitaries_list})

    def _construct_circuits(self) -> List[QProgram]:
        """Construct one random-measurement circuit per sampled unitary set.

        Each circuit is a deep copy of the state-preparation program followed by a `U` gate
        on every qubit, in the iteration order of `self._unitaries`.

        :return: List[QProgram], the measurement circuits (measurements are added later)
        """
        qp_list = []
        for u3_list in self._unitaries.values():
            qp = deepcopy(self._qp)
            for i, u3_param in enumerate(u3_list):
                u3 = RotationGate.createRotationGateInstance('U', *u3_param)
                u3(qp.Q[i])
            qp_list.append(qp)
        return qp_list

    def _compute_fidelity(self, counts1: List[str], counts2: List[str]) -> float:
        r"""Estimate :math:`{\rm Tr}[\rho_1 \rho_2]` from the measurement records of two devices.

        The records are laid out as `samples` consecutive blocks of `shots` outcomes, one block
        per sampled unitary. For each block the outcome histograms are correlated via
        :math:`\sum_{s,s'} (-2)^{-D[s,s']} P_1(s) P_2(s')`, and the average over blocks is scaled
        by :math:`2^N`.

        When `counts1 is counts2` (a purity estimate), each shot is NOT paired with itself:
        only ordered pairs of distinct shots are used, which gives an unbiased estimate of
        :math:`{\rm Tr}[\rho^2]`. Pairing a shot with itself would always give :math:`D = 0` and
        therefore the meaningless value :math:`2^N`.

        :param counts1: List[str], measurement outcomes of the first device
        :param counts2: List[str], measurement outcomes of the second device
        :return: float, the estimate, clipped from above at 1.0
        """
        n = number_of_qubits(self._qp)
        same_record = counts1 is counts2
        shots = self._shots

        correlation = 0.0
        for k in range(self._samples):
            window = slice(k * shots, (k + 1) * shots)
            hist1 = collections.Counter(counts1[window])
            hist2 = hist1 if same_record else collections.Counter(counts2[window])
            correlation += self._sample_correlation(hist1, hist2, shots, same_record)

        fidelity = (2 ** n) * correlation / self._samples
        return fidelity if fidelity < 1.0 else 1.0

    @staticmethod
    def _sample_correlation(hist1: collections.Counter, hist2: collections.Counter,
                            shots: int, same_record: bool) -> float:
        r"""Estimate :math:`\sum_{s,s'} (-2)^{-D[s,s']} P_1(s) P_2(s')` for one random unitary.

        :param hist1: outcome histogram of the first device for this unitary
        :param hist2: outcome histogram of the second device for this unitary
        :param shots: int, number of shots behind each histogram
        :param same_record: bool, True when both histograms come from the very same shots
        :return: float, the estimated correlation
        """
        outcomes1 = list(hist1)
        outcomes2 = list(hist2)
        # One row per distinct outcome, one column per character of the outcome string
        bits1 = np.array([list(s) for s in outcomes1])
        bits2 = np.array([list(s) for s in outcomes2])
        # distance[a, b] = Hamming distance between outcomes1[a] and outcomes2[b]
        distance = (bits1[:, np.newaxis, :] != bits2[np.newaxis, :, :]).sum(axis=2)

        freq1 = np.array([hist1[s] for s in outcomes1], dtype=float)
        freq2 = np.array([hist2[s] for s in outcomes2], dtype=float)
        pair_counts = np.outer(freq1, freq2)
        if same_record:
            # c(s)^2 ordered pairs include c(s) self-pairs; keep only the c(s)(c(s)-1) distinct ones
            pair_counts -= np.diag(freq1)
            normalization = shots * (shots - 1)
        else:
            normalization = shots * shots

        return float(np.sum(np.power(-2.0, -distance) * pair_counts)) / normalization


def _mapping_qubits(qp: QProgram, qubits: List[int] = None) -> QProgram:
    """Remap the qubits which we are interested in, and append Z-basis measurements.

    If `qubits` is None or empty, `qp` itself gets a `MeasureZ` on all of its qubits and is
    returned. Otherwise, a new program is built whose register is wide enough to hold the
    largest target index. Circuit qubit `i` is mapped to the `i`-th smallest target index,
    and exactly the target qubits are measured, each into the classical bit of the same
    index.

    :param qp: QProgram, the measurement circuit (without measurements)
    :param qubits: List[int], optional, the physical qubit indices to map onto; not modified
    :return: QProgram, the program with measurements appended
    """
    if not qubits:
        MeasureZ(*qp.Q.toListPair())
        return qp

    # Sort a copy so the caller's (snapshot's) qubit list is left untouched
    targets = sorted(qubits)
    width = targets[-1] + 1
    qp_new = QEnv()
    qp_new.Q.createList(width)
    for gate in qp.circuit:
        qp_new.circuit.append(CircuitLine(gate.data, [targets[i] for i in gate.qRegList]))
    qreglist, indexlist = qp_new.Q.toListPair()
    MeasureZ(qRegList=[qreglist[x] for x in targets],
             cRegList=[indexlist[x] for x in targets])
    return qp_new


def read_quantum_snapshot(file_name: str, qc_name: str = None) -> QuantumSnapshot:
    """Input a file_name and return a Quantum Snapshot.

    Reads `<file_name>.devinf` (JSON with keys `qubits`, `counts`, `unitaries`, `qp`) and
    prints the stored circuit description.

    :param file_name: the name of a file (without the `.devinf` suffix), which contains the
                      information of quantum device
    :param qc_name: the qc_name of return QuantumSnapshot class; defaults to `file_name`
    :return: QuantumSnapshot, with `qc=None` and all recorded data restored
    """
    with open(file_name + '.devinf', 'r', encoding='utf-8') as f:
        file_data = json.load(f)
    if qc_name is None:
        qc_name = file_name
    print("Note that the quantum circuit of file `{}.devinf` is: \n{}".format(file_name, file_data['qp']))
    snapshot = QuantumSnapshot(qc_name,
                               qc=None,
                               qubits=file_data['qubits'],
                               counts=file_data['counts'],
                               unitaries=file_data['unitaries'],
                               qp=file_data['qp'])
    # Assign the circuit description through the setter so it is restored regardless of
    # which keyword arguments the constructor consumes.
    snapshot.qp = file_data['qp']
    return snapshot


def hamming_distance(bits1: str, bits2: str) -> int:
    """Compute the hamming distance between bits1 and bits2.

    The distance is counted exactly as an integer. Rescaling a normalized distance by the
    string length (``scipy.spatial.distance.hamming(u, v) * len(u)``) can give non-integral
    floats such as ``0.9999999999999999`` (e.g. one mismatch in 49 bits). A negative base
    raised to such a power, like ``(-2) ** -0.9999999999999999``, is a complex number in
    Python.

    :param bits1: str, the first bit string, e.g. ``'0101'``
    :param bits2: str, the second bit string, of the same length as `bits1`
    :return: int, the number of positions at which the two strings differ
    :raises ValueError: if the two strings have different lengths
    """
    if len(bits1) != len(bits2):
        raise ValueError("hamming_distance(): the two bit strings must have equal lengths!")
    return sum(b1 != b2 for b1, b2 in zip(bits1, bits2))



