#(PP) - waiting for Jan's reply

import sys
import os
import numpy as np
import pickle
from qrem.functions_qrem import ancillary_functions as anf


def TVD(p,q):
    """
    Total variation distance between two probability distributions.

    Returns 0.5 * sum_k |p_k - q_k| over the entries of p and q taken pairwise;
    like zip, it stops at the shorter input. Two empty inputs give 0.0.
    """
    half_abs_differences = (0.5 * np.abs(p_k - q_k) for p_k, q_k in zip(p, q))
    return sum(half_abs_differences, 0.)

def Overlap(s1,s2):
    """
    Squared scalar product between two Pauli eigenstates, used to build the
    overlap dictionary needed by the coherence indicator.

    Input: s1, s2 - symbols of Pauli eigenstates, as in the SeparableCircuitsCreator class

    Output: 1 for identical symbols, 0 for the orthogonal pairs {'2','3'} and
            {'4','5'}, and 0.5 for any other combination
    """
    if s1 == s2:
        return 1

    orthogonal_pairs = ({'2', '3'}, {'4', '5'})
    if {s1, s2} in orthogonal_pairs:
        return 0
    return 0.5

def compute_indicator_normalization(dim,setting1,setting2,overlap_dic):
    """
    Normalisation factor of the coherence indicator; currently supports
    two-qubit reduced POVMs only.

    Input:
    dim - dimension of the reduced subspace
    setting1, setting2 - two-symbol strings, each symbol labelling the input
                         Pauli eigenstate of one qubit
    overlap_dic - dictionary mapping a two-symbol string to the squared scalar
                  product of the corresponding Pauli eigenstates

    Output: dim * sqrt(2 * (1 - o_2 * o_1)), where o_k is the overlap of the
            k-th qubit's input states in setting1 and setting2
    """
    second_qubit_overlap = overlap_dic[setting1[1] + setting2[1]]
    first_qubit_overlap = overlap_dic[setting1[0] + setting2[0]]
    product_overlap = second_qubit_overlap * first_qubit_overlap
    return dim * np.sqrt(2 * (1 - product_overlap))

def compute_pauli_marginals(marginals_dictionary, subsets_list):
    """
    Sum the marginal distributions of qubit pairs over X/Y Pauli input settings.

    For every subset (qubit pair) in ``subsets_list``, the marginals of all
    measurement settings in which BOTH qubits of the pair carry a symbol other
    than '0' or '1' are summed. They are grouped by the two-symbol input label
    of that pair ('22', '23', ..., '55').

    Input:
    marginals_dictionary - maps a setting string (one symbol per qubit, indexed
                           by qubit position) to a dictionary that maps a qubit
                           subset to its marginal distribution on that subset
                           (assumed to have 4 entries, i.e. two qubits - TODO confirm
                           for other subset sizes)
    subsets_list - iterable of qubit-index pairs, used as keys of the inner dictionaries

    Output: dictionary mapping each subset to a list
            [setting_dictionary, normalisation_dictionary], where
            setting_dictionary - keys are two-symbol input labels, values are the
                                 unnormalised (summed) distributions for that label
            normalisation_dictionary - same keys, values count how many settings
                                       contributed to each sum
    Raises KeyError if a non-'0'/'1' symbol lies outside '2'-'5'.
    """
    labels = [str(i) + str(j) for i in range(2, 6) for j in range(2, 6)]
    computational_symbols = ('0', '1')

    pauli_subset_dictionary = {}
    for subset in subsets_list:
        # A fresh array per label and per subset. dict.fromkeys(keys, [0., 0., 0., 0.])
        # would share ONE mutable list among all keys, and `list += array` extends
        # that list instead of adding element-wise.
        setting_dictionary = {label: np.zeros(4) for label in labels}
        normalisation_dictionary = dict.fromkeys(labels, 0)

        for setting in marginals_dictionary:
            s1 = setting[subset[0]]
            s2 = setting[subset[1]]
            # Only settings where both qubits are prepared in X/Y Pauli eigenstates count.
            if s1 in computational_symbols or s2 in computational_symbols:
                continue
            setting_dictionary[s1 + s2] += marginals_dictionary[setting][subset]
            normalisation_dictionary[s1 + s2] += 1

        pauli_subset_dictionary[subset] = [setting_dictionary, normalisation_dictionary]
    return pauli_subset_dictionary

# TODO_JT (PP) not sure why this is defined outside of function scope. @Jan - are you still using the compute_coherence_indicator function? 
# Overlap dictionary for the Pauli case: maps every two-symbol combination of
# the Pauli eigenstate symbols '2'-'5' (e.g. '23') to the value returned by
# Overlap for that pair. compute_coherence_indicator reads it as a module global.
settings_list = ['2','3','4','5']
overlap_dictionary = {
    first + second: Overlap(first, second)
    for first in settings_list
    for second in settings_list
}

def compute_coherence_indicator(marginals_dictionary, subset_list):
    """
    Coherence indicator for each qubit pair.

    For every subset, the summed marginals from compute_pauli_marginals are
    divided by their counts. Then the TVD between the resulting distributions
    is computed for every unordered pair of distinct X/Y Pauli input labels,
    and each TVD is divided by compute_indicator_normalization (dim = 2).

    Output: dictionary mapping each subset to [indicator_values, label_pairs],
            where label_pairs[k] is the (label_a, label_b) tuple that
            produced indicator_values[k]. A label that never occurs has count 0,
            so its division yields nan/inf (numpy division by zero).
    """
    pauli_subset_dictionary = compute_pauli_marginals(marginals_dictionary,subset_list)

    # Two-symbol labels of X/Y Pauli input states on a qubit pair: '22', '23', ..., '55'.
    pair_labels = [first + second for first in '2345' for second in '2345']
    # Unordered pairs of distinct labels, in row-major (i < j) order.
    label_pairs = [
        (label_a, label_b)
        for position, label_a in enumerate(pair_labels)
        for label_b in pair_labels[position + 1:]
    ]

    indicator_dic = {}
    for subset, (summed_distributions, counts) in pauli_subset_dictionary.items():
        indicator_values = []
        for label_a, label_b in label_pairs:
            distance = TVD(
                summed_distributions[label_a] / counts[label_a],
                summed_distributions[label_b] / counts[label_b],
            )
            normalization = compute_indicator_normalization(2, label_a, label_b, overlap_dictionary)
            indicator_values.append(distance / normalization)
        # Each subset gets its own list of label pairs.
        indicator_dic[subset] = [indicator_values, list(label_pairs)]

    return indicator_dic



