import pytest
import json

from qutrunk.circuit import QCircuit
from qutrunk.circuit.gates import (
    H, Measure, CNOT, Toffoli, P, R, Rx, Ry, Rz, S, Sdg, T, Tdg, X, Y, Z, MCX, MCZ,
    NOT, Swap, SqrtSwap, SqrtX, All, CP, CX, CY, CZ, CRx, CRy, CRz, Rxx, Ryy, Rzz, 
    U1, U2, U3, Barrier, iSwap, CR, CU, CU1, CU3, X1, Y1
)
from numpy import pi
from qutrunk.backends import BackendQuSprout, backend
import math

# Default absolute tolerance for comparing amplitude components.
PRECISION = 1e-10

def check_all_state(res, resbox, precision=PRECISION):
    """Return True if two state dumps agree amplitude by amplitude.

    Each element of ``res`` / ``resbox`` is a ``"real,imag"`` string (the
    format the tests below pass in from ``QCircuit.get_all_state()``). The
    dumps match when they have the same length and every real and imaginary
    component differs by at most ``precision``.

    Args:
        res: first list of ``"real,imag"`` amplitude strings.
        resbox: second list of ``"real,imag"`` amplitude strings.
        precision: absolute tolerance applied to each component.

    Returns:
        True if the dumps match within ``precision``, False otherwise.
        NaN components never match.

    Raises:
        ValueError: if an element is not two comma-separated floats.
    """
    if len(res) != len(resbox):
        return False

    for amp, amp_box in zip(res, resbox):
        real, imag = (float(part) for part in amp.split(','))
        real_box, imag_box = (float(part) for part in amp_box.split(','))
        # isclose with rel_tol=0 is a pure absolute-tolerance check; unlike
        # ``fabs(a - b) > tol`` it does not treat NaN as a match.
        if not (math.isclose(real, real_box, rel_tol=0.0, abs_tol=precision)
                and math.isclose(imag, imag_box, rel_tol=0.0, abs_tol=precision)):
            return False

    return True

def test_h_inverse_gate():
    """H followed by an inverted H must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        H | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_p_inverse_gate():
    """P(pi/2) followed by its inverse must restore the prepared state.

    P is diagonal, so starting from |0> the gates never change the state and
    the comparison would check nothing. Rotate the qubit first so both basis
    amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    P(pi / 2) | qr[0]
    P(pi / 2) | qr[0]
    # Invert the most recently appended command (the second P).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cp_inverse_gate():
    """CP(pi/2) followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CP(pi / 2) | (qr[0], qr[1])
    CP(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CP).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_r_inverse_gate():
    """R(pi/2, pi/2) followed by its inverse must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        R(pi / 2, pi / 2) | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_rx_inverse_gate():
    """Rx(pi/2) followed by its inverse must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        Rx(pi / 2) | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_rxx_inverse_gate():
    """Rxx(pi/2) followed by its inverse must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(2)
    state_before = qc.get_all_state()

    for _ in range(2):
        Rxx(pi / 2) | (qubits[0], qubits[1])
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_ryy_inverse_gate():
    """Ryy(pi/2) followed by its inverse must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(2)
    state_before = qc.get_all_state()

    for _ in range(2):
        Ryy(pi / 2) | (qubits[0], qubits[1])
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_rzz_inverse_gate():
    """Rzz(pi/2) followed by its inverse must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(2)
    state_before = qc.get_all_state()

    for _ in range(2):
        Rzz(pi / 2) | (qubits[0], qubits[1])
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_not_inverse_gate():
    """NOT followed by its inverse must restore the initial state.

    Applies the imported ``NOT`` gate itself rather than delegating to the
    X test, so the NOT entry point is actually exercised.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    orgres = circuit.get_all_state()

    NOT | qr[0]
    NOT | qr[0]
    # Invert the most recently appended command (the second NOT).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_x_inverse_gate():
    """X followed by an inverted X must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        X | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_y_inverse_gate():
    """Y followed by an inverted Y must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        Y | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_z_inverse_gate():
    """Z followed by its inverse must restore the prepared state.

    Z is diagonal, so starting from |0> the gates never change the state and
    the comparison would check nothing. Rotate the qubit first so both basis
    amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    Z | qr[0]
    Z | qr[0]
    # Invert the most recently appended command (the second Z).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_s_inverse_gate():
    """S followed by its inverse must restore the prepared state.

    S is diagonal, so starting from |0> the gates never change the state and
    the comparison would check nothing. Rotate the qubit first so both basis
    amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    S | qr[0]
    S | qr[0]
    # Invert the most recently appended command (the second S).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_t_inverse_gate():
    """T followed by its inverse must restore the prepared state.

    T is diagonal, so starting from |0> the gates never change the state and
    the comparison would check nothing. Rotate the qubit first so both basis
    amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    T | qr[0]
    T | qr[0]
    # Invert the most recently appended command (the second T).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_sdg_inverse_gate():
    """Sdg followed by its inverse must restore the prepared state.

    Sdg is diagonal, so starting from |0> the gates never change the state
    and the comparison would check nothing. Rotate the qubit first so both
    basis amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    Sdg | qr[0]
    Sdg | qr[0]
    # Invert the most recently appended command (the second Sdg).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_tdg_inverse_gate():
    """Tdg followed by its inverse must restore the prepared state.

    Tdg is diagonal, so starting from |0> the gates never change the state
    and the comparison would check nothing. Rotate the qubit first so both
    basis amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    Tdg | qr[0]
    Tdg | qr[0]
    # Invert the most recently appended command (the second Tdg).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_sqrtswap_inverse_gate():
    """SqrtSwap followed by its inverse must restore the prepared state.

    |00> is unchanged by gates that exchange the two qubits, so starting
    there the comparison would check nothing. H and Ry(pi/3) give the qubits
    distinct states so the exchange is observable.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    SqrtSwap | (qr[0], qr[1])
    SqrtSwap | (qr[0], qr[1])
    # Invert the most recently appended command (the second SqrtSwap).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_swap_inverse_gate():
    """Swap followed by its inverse must restore the prepared state.

    |00> is unchanged by exchanging the two qubits, so starting there the
    comparison would check nothing. H and Ry(pi/3) give the qubits distinct
    states so the exchange is observable.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    Swap | (qr[0], qr[1])
    Swap | (qr[0], qr[1])
    # Invert the most recently appended command (the second Swap).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cnot_inverse_gate():
    """CNOT followed by its inverse must restore the prepared state.

    Applies the imported ``CNOT`` gate itself rather than delegating to the
    CX test. On |00> a controlled gate never acts, so H and Ry(pi/3) first
    put the qubits into distinct superpositions with every basis amplitude
    non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CNOT | (qr[0], qr[1])
    CNOT | (qr[0], qr[1])
    # Invert the most recently appended command (the second CNOT).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cx_inverse_gate():
    """CX followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CX | (qr[0], qr[1])
    CX | (qr[0], qr[1])
    # Invert the most recently appended command (the second CX).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cy_inverse_gate():
    """CY followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CY | (qr[0], qr[1])
    CY | (qr[0], qr[1])
    # Invert the most recently appended command (the second CY).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cz_inverse_gate():
    """CZ followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CZ | (qr[0], qr[1])
    CZ | (qr[0], qr[1])
    # Invert the most recently appended command (the second CZ).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_u3_inverse_gate():
    """U3 followed by its inverse must restore the initial state.

    Uses generic angles: U3(pi, 0, pi) equals X, which is its own inverse,
    so with those angles the check would pass even if ``inverse`` were
    ignored.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    orgres = circuit.get_all_state()

    U3(pi / 3, pi / 4, pi / 5) | qr[0]
    U3(pi / 3, pi / 4, pi / 5) | qr[0]
    # Invert the most recently appended command (the second U3).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_u2_inverse_gate():
    """U2 followed by its inverse must restore the initial state.

    Uses generic angles: U2(0, pi) equals H, which is its own inverse, so
    with those angles the check would pass even if ``inverse`` were ignored.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    orgres = circuit.get_all_state()

    U2(pi / 4, pi / 3) | qr[0]
    U2(pi / 4, pi / 3) | qr[0]
    # Invert the most recently appended command (the second U2).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_u1_inverse_gate():
    """U1(pi/2) followed by its inverse must restore the prepared state.

    U1 is diagonal, so starting from |0> the gates never change the state
    and the comparison would check nothing. Rotate the qubit first so both
    basis amplitudes are non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    Ry(pi / 3) | qr[0]
    orgres = circuit.get_all_state()

    U1(pi / 2) | qr[0]
    U1(pi / 2) | qr[0]
    # Invert the most recently appended command (the second U1).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_crx_inverse_gate():
    """CRx(pi/2) followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CRx(pi / 2) | (qr[0], qr[1])
    CRx(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CRx).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cry_inverse_gate():
    """CRy(pi/2) followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CRy(pi / 2) | (qr[0], qr[1])
    CRy(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CRy).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_crz_inverse_gate():
    """CRz(pi/2) followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CRz(pi / 2) | (qr[0], qr[1])
    CRz(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CRz).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_x1_inverse_gate():
    """X1 followed by an inverted X1 must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        X1 | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_y1_inverse_gate():
    """Y1 followed by an inverted Y1 must leave the register unchanged."""
    qc = QCircuit()
    qubits = qc.allocate(1)
    state_before = qc.get_all_state()

    for _ in range(2):
        Y1 | qubits[0]
    # Turn the second application (command index 1) into its inverse.
    qc.cmds[1].inverse = True

    state_after = qc.get_all_state()
    assert check_all_state(state_before, state_after)

def test_cu1_inverse_gate():
    """CU1(pi/2) followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CU1(pi / 2) | (qr[0], qr[1])
    CU1(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CU1).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cu3_inverse_gate():
    """CU3 followed by its inverse must restore the prepared state.

    Two fixes keep this test meaningful:

    * On |00> the controlled gate never acts, so H and Ry(pi/3) first put the
      qubits into distinct superpositions with every basis amplitude
      non-zero.
    * Generic angles are used because U3(pi/2, pi/2, pi/2) is Hermitian (its
      own inverse), which would hide an ignored ``inverse`` flag.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CU3(pi / 3, pi / 4, pi / 5) | (qr[0], qr[1])
    CU3(pi / 3, pi / 4, pi / 5) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CU3).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_u_inverse_gate():
    """U3 with generic angles followed by its inverse must restore |0>.

    ``U`` is not imported in this module, so U3 stands in for it. The angles
    are chosen so the gate is not its own inverse (unlike U3(pi, 0, pi),
    which equals X), so an ignored ``inverse`` flag would be detected.
    """
    circuit = QCircuit()
    qr = circuit.allocate(1)
    orgres = circuit.get_all_state()

    U3(pi / 4, pi / 6, pi / 8) | qr[0]
    U3(pi / 4, pi / 6, pi / 8) | qr[0]
    # Invert the most recently appended command (the second U3).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cu_inverse_gate():
    """CU followed by its inverse must restore the prepared state.

    On |00> the controlled gate never acts, so the comparison would check
    nothing. H and Ry(pi/3) put the two qubits into distinct superpositions
    so every basis amplitude is non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CU(pi / 2, pi / 2, pi / 2, pi / 2) | (qr[0], qr[1])
    CU(pi / 2, pi / 2, pi / 2, pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CU).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_cr_inverse_gate():
    """CR(pi/2) followed by its inverse must restore the prepared state.

    Gates are applied with ``|`` like every other test in this module; the
    earlier ``*`` form did not match that pattern. On |00> the controlled
    gate never acts, so H and Ry(pi/3) first put the qubits into distinct
    superpositions with every basis amplitude non-zero.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    CR(pi / 2) | (qr[0], qr[1])
    CR(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second CR).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)

def test_iswap_inverse_gate():
    """iSwap(pi/2) followed by its inverse must restore the prepared state.

    Gates are applied with ``|`` like every other test in this module; the
    earlier ``*`` form did not match that pattern. |00> is unchanged by
    gates that exchange the two qubits, so H and Ry(pi/3) first give the
    qubits distinct states so the exchange is observable.
    """
    circuit = QCircuit()
    qr = circuit.allocate(2)
    H | qr[0]
    Ry(pi / 3) | qr[1]
    orgres = circuit.get_all_state()

    iSwap(pi / 2) | (qr[0], qr[1])
    iSwap(pi / 2) | (qr[0], qr[1])
    # Invert the most recently appended command (the second iSwap).
    circuit.cmds[-1].inverse = True

    finalres = circuit.get_all_state()

    assert check_all_state(orgres, finalres)