import cotengra as ctg  # type: ignore
from rust_circuit import EinsumSpec, Einsum, ScalarConstant, einsum_nest_optimize
from interp.tools.perf_timer import catchtime

# Benchmark cases for comparing einsum contraction-order optimizers.
#
# Each entry is a pair ``(spec, sizes)``:
#   * ``spec`` is an einsum string: comma-separated index letters for each
#     operand, then ``->`` and the output indices. Examples 3 and 4 need more
#     than 26 distinct indices and continue with upper-case letters.
#   * ``sizes`` maps an integer index id to that index's dimension size; each
#     table has exactly one entry per distinct letter in its ``spec``.
# NOTE(review): the loop below looks these ids up through the integer lists
# returned by ``EinsumSpec.string_to_ints``; presumably that numbers letters in
# order of first appearance (a=0, b=1, c=2, ...) — confirm against rust_circuit.
examples = [
    # Example 1: 22 distinct indices (a-v).
    (
        "a,a,bca,def,dgh,h,h,bch,ig,ig,bi,ig,bjk,lgk,f,f,mnf,omie,ie,bjp,mnq,pr,sr,tuv,uvs,s,s,i,i,u->iauq",
        {
            0: 384,
            1: 8,
            2: 48,
            3: 16384,
            4: 31,
            5: 384,
            6: 31,
            7: 384,
            8: 31,
            9: 48,
            10: 384,
            11: 16384,
            12: 8,
            13: 48,
            14: 16384,
            15: 384,
            16: 384,
            17: 2,
            18: 384,
            19: 16384,
            20: 8,
            21: 48,
        },
    ),
    # Example 2: 26 distinct indices (a-z).
    (
        "a,a,bca,def,dgh,h,h,bch,ig,ig,ig,b,i,ij,bkl,mjl,f,f,nof,pnie,ie,bkq,nor,r,r,rs,tis,su,qv,wv,xyz,yzw,w,w,i,i,y->iauy",
        {
            0: 384,
            1: 8,
            2: 48,
            3: 16384,
            4: 31,
            5: 384,
            6: 31,
            7: 384,
            8: 31,
            9: 31,
            10: 48,
            11: 384,
            12: 16384,
            13: 8,
            14: 48,
            15: 16384,
            16: 384,
            17: 384,
            18: 768,
            19: 16384,
            20: 384,
            21: 2,
            22: 384,
            23: 16384,
            24: 8,
            25: 48,
        },
    ),
    # Example 3: 38 distinct indices (a-z, A-L).
    (
        "a,a,bca,def,dgh,h,h,bch,ig,ig,ig,b,i,ij,bkl,mjl,f,f,nof,pnqe,qe,bkr,nos,s,s,st,uqt,tv,iw,wx,qy,yz,AB,Bx,CD,Dz,EFC,C,v,v,FGv,FGH,IJA,A,r,r,JKr,JKL->iaLH",
        {
            0: 384,
            1: 8,
            2: 48,
            3: 16384,
            4: 31,
            5: 384,
            6: 31,
            7: 384,
            8: 31,
            9: 31,
            10: 48,
            11: 384,
            12: 16384,
            13: 8,
            14: 48,
            15: 16384,
            16: 31,
            17: 384,
            18: 384,
            19: 768,
            20: 16384,
            21: 384,
            22: 8,
            23: 2,
            24: 8,
            25: 2,
            26: 31,
            27: 8,
            28: 31,
            29: 8,
            30: 16384,
            31: 8,
            32: 48,
            33: 384,
            34: 16384,
            35: 8,
            36: 48,
            37: 384,
        },
    ),
    # Example 4: 37 distinct indices (a-z, A-K).
    (
        "a,a,bca,def,dgh,h,h,bch,ig,ig,bi,ig,bjk,lgk,f,f,mnf,ompe,pe,bjq,mnr,r,r,rs,tps,su,pv,vw,ix,xy,zA,Aw,BC,Cy,DEB,B,q,q,EFq,EFG,HIz,z,u,u,IJu,IJK->iaGK",
        {
            0: 384,
            1: 8,
            2: 48,
            3: 16384,
            4: 31,
            5: 384,
            6: 31,
            7: 384,
            8: 31,
            9: 48,
            10: 384,
            11: 16384,
            12: 8,
            13: 48,
            14: 16384,
            15: 31,
            16: 384,
            17: 384,
            18: 768,
            19: 16384,
            20: 384,
            21: 8,
            22: 2,
            23: 8,
            24: 2,
            25: 31,
            26: 8,
            27: 31,
            28: 8,
            29: 16384,
            30: 8,
            31: 48,
            32: 384,
            33: 16384,
            34: 8,
            35: 48,
            36: 384,
        },
    ),
]

# Benchmark: for each example, time the Rust einsum-nest optimizer and
# cotengra's kahypar-based HyperOptimizer, then print both contraction costs
# and their ratio.
for example in examples:
    spec_string, size_dict = example
    # ``ints[0]`` holds one integer index list per operand; the whole of
    # ``ints`` is splatted into ``opt.search`` below as ``(inputs, output)``.
    ints = EinsumSpec.string_to_ints(spec_string)
    # Every operand is a constant-1 placeholder; only its shape, looked up
    # from the size table, carries information for the optimizer here.
    rusty = Einsum.from_einsum_string(
        spec_string,
        [ScalarConstant(1, tuple(size_dict[i] for i in ints_here)) for ints_here in ints[0]],
    )

    # Time only the optimization call; checking and reporting stay outside
    # the timer so both branches measure the same kind of work.
    with catchtime("rust"):
        optimized = einsum_nest_optimize(rusty)
    assert optimized is not None
    rust_flops = optimized.total_flops()
    print("rust", rust_flops)

    opt = ctg.HyperOptimizer(methods=["kahypar"])
    # As with the rust branch: the cost query and print are kept out of the
    # timed region so the reported search time is directly comparable.
    with catchtime("search"):
        tree = opt.search(*ints, size_dict)
    cotengra_cost = tree.contraction_cost()
    print("cotengra", cotengra_cost)
    # NOTE(review): assumes ``total_flops`` and ``contraction_cost`` count
    # comparable operations; confirm before reading a ratio > 1 as "rust
    # found a worse contraction order".
    print("ratio", rust_flops / cotengra_cost)
