#!/usr/bin/env python
# -*- coding: utf-8 -*-


import numpy as np
import mindspore as ms
from mindspore.ops._primitive_cache import _get_cache_prim
from msadapter.pytorch.common._inner import _out_inplace_assign
from msadapter.utils import unsupported_attr
from msadapter.pytorch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor

def eigh(A, UPLO='L', *, out=None):
    """Eigen-decomposition of a Hermitian/symmetric matrix, computed with NumPy.

    Args:
        A: input tensor; must expose ``.numpy()``.
        UPLO: which triangle ``np.linalg.eigh`` reads ('L' or 'U').
        out: optional output, forwarded to ``_out_inplace_assign``.

    Returns:
        The ``(eigenvalues, eigenvectors)`` pair from ``np.linalg.eigh``,
        each wrapped in ``ms.Tensor`` and passed through ``_out_inplace_assign``.
    """
    # TODO use numpy api now -- this runs on the host via NumPy; switch to a
    # native MindSpore op once one is available.
    values, vectors = np.linalg.eigh(A.numpy(), UPLO=UPLO)
    result = (ms.Tensor(values), ms.Tensor(vectors))
    return _out_inplace_assign(out, result, "eigh")


def solve(A, B, *, left=True, out=None):
    """Solve a (possibly batched) square linear system, computed with NumPy.

    With ``left=True`` (default) the result ``X`` satisfies ``A @ X = B``.
    With ``left=False`` the result ``X`` satisfies ``X @ A = B``. It is
    obtained from the equivalent transposed system ``A^T @ X^T = B^T``.

    Args:
        A: coefficient tensor; must expose ``.numpy()``.
        B: right-hand side tensor; must expose ``.numpy()``. When
            ``left=False`` it must be a matrix (at least 2-D).
        left: whether ``A`` multiplies ``X`` from the left (True) or the
            right (False).
        out: optional output, forwarded to ``_out_inplace_assign``.

    Returns:
        The solution wrapped in ``ms.Tensor`` and passed through
        ``_out_inplace_assign``.

    Raises:
        ValueError: if ``left=False`` and ``B`` is 1-D (the vector case is
            ambiguous for ``X @ A = B``).
    """
    # TODO use numpy api now -- this runs on the host via NumPy; switch to a
    # native MindSpore op once one is available.
    a_np = A.numpy()
    b_np = B.numpy()
    if left:
        output = np.linalg.solve(a_np, b_np)
    else:
        if b_np.ndim < 2:
            raise ValueError(
                "solve: vector right-hand side is not supported for left=False")
        # X A = B  <=>  A^T X^T = B^T. Use a plain (non-conjugate) transpose of
        # the last two axes so batched and complex inputs are handled correctly.
        x_t = np.linalg.solve(a_np.swapaxes(-1, -2), b_np.swapaxes(-1, -2))
        output = x_t.swapaxes(-1, -2)
    return _out_inplace_assign(out, ms.Tensor(output), "solve")

def eig(A, *, out=None):
    """Eigen-decomposition of a general square matrix via MindSpore's ``Eig``.

    ``compute_v=True`` requests eigenvectors in addition to eigenvalues. The
    primitive is obtained through ``_get_cache_prim``, presumably so that
    repeated calls reuse one instance.
    """
    ms_a = cast_to_ms_tensor(A)
    eig_op = _get_cache_prim(ms.ops.Eig)(compute_v=True)
    return _out_inplace_assign(out, eig_op(ms_a), "eig")

def slogdet(A, *, out=None):
    """Sign and log-absolute-determinant of ``A`` via ``ms.ops.slogdet``.

    Returns the ``(sign, logabsdet)`` pair through ``_out_inplace_assign``.
    """
    sign, logabsdet = ms.ops.slogdet(cast_to_ms_tensor(A))
    return _out_inplace_assign(out, (sign, logabsdet), "slogdet")

def det(A, *, out=None):
    """Determinant of ``A`` computed by ``ms.ops.det``."""
    determinant = ms.ops.det(cast_to_ms_tensor(A))
    return _out_inplace_assign(out, determinant, "det")

def cholesky(A, *, upper=False, out=None):
    """Cholesky factorization of ``A`` via ``ms.ops.cholesky``.

    ``upper`` is forwarded positionally as the op's second argument and
    chooses which triangular factor is returned.
    """
    # TODO: ms.ops.cholesky to support complex type
    factor = ms.ops.cholesky(cast_to_ms_tensor(A), upper)
    return _out_inplace_assign(out, factor, "cholesky")

def inv(A, *, out=None):
    """Matrix inverse of ``A`` computed by ``ms.ops.inverse``."""
    inverse = ms.ops.inverse(cast_to_ms_tensor(A))
    return _out_inplace_assign(out, inverse, "inv")

def matmul(input, other, *, out=None):
    """Matrix product of ``input`` and ``other`` via ``ms.ops.matmul``."""
    lhs, rhs = cast_to_ms_tensor(input), cast_to_ms_tensor(other)
    return _out_inplace_assign(out, ms.ops.matmul(lhs, rhs), "matmul")

def diagonal(A, *, offset=0, dim1=-2, dim2=-1):
    """Diagonal(s) of ``A`` taken over axes ``dim1``/``dim2``.

    Delegates to ``ms.ops.diagonal``. There is no ``out`` argument, so the
    result is converted straight back to an adapter tensor.
    """
    ms_a = cast_to_ms_tensor(A)
    diag = ms.ops.diagonal(ms_a, offset=offset, dim1=dim1, dim2=dim2)
    return cast_to_adapter_tensor(diag)

def multi_dot(tensors, *, out=None):
    """Chained matrix product of the sequence ``tensors``.

    The whole sequence is converted with ``cast_to_ms_tensor`` and the product
    is computed by ``ms.numpy.multi_dot``.
    """
    ms_tensors = cast_to_ms_tensor(tensors)
    product = ms.numpy.multi_dot(ms_tensors)
    return _out_inplace_assign(out, product, "multi_dot")

def householder_product(A, tau, *, out=None):
    """Product of Householder reflectors described by ``A`` and ``tau``.

    Delegates to ``ms.ops.orgqr`` after converting both inputs.
    """
    ms_a, ms_tau = cast_to_ms_tensor(A), cast_to_ms_tensor(tau)
    reflectors = ms.ops.orgqr(ms_a, ms_tau)
    return _out_inplace_assign(out, reflectors, "householder_product")
