# Copyright 2020-2021 Terry Yue Zhuo
# Copyright 2020-2021 Data61/CSIRO

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------

import sys
import math
import pytest
import pyarma as pa

def _abs_diff_sum(lhs, rhs):
    """Return ``accu(abs(lhs - rhs))``, the summed absolute element-wise difference.

    Accepts any pair of operands that pyarma can subtract (a mat, a subview,
    or an expression built from them).
    """
    return pa.accu(pa.abs(lhs - rhs))


def test_expr_elem_1():
    """Check element-wise arithmetic between a mat, scalars, and subviews.

    Each operation on the fixed 5x6 matrix ``A`` is compared against either an
    algebraically equivalent expression or a precomputed reference matrix.

    ``math.isclose(x, 0.0)`` without ``abs_tol`` only passes when ``x`` is
    exactly zero (the relative tolerance is scaled by the operands, one of
    which is zero), so those checks demand bit-exact results. Comparisons
    against the rounded reference matrices use ``tol`` instead.
    """
    # Absolute tolerance for comparisons against the reference matrices.
    tol = 0.0001

    A = pa.mat("\
        0.061198   0.201990   0.019678  -0.493936  -0.126745   0.051408;\
        0.437242   0.058956  -0.149362  -0.045465   0.296153   0.035437;\
        -0.492474  -0.031309   0.314156   0.419733   0.068317  -0.454499;\
        0.336352   0.411541   0.458476  -0.393139  -0.135040   0.373833;\
        0.239585  -0.428913  -0.406953  -0.291020  -0.353768   0.258704;\
    ")

    A_plus_2 = pa.mat("\
        2.061198   2.201990   2.019678   1.506064   1.873255   2.051408;\
        2.437242   2.058956   1.850638   1.954535   2.296153   2.035437;\
        1.507526   1.968691   2.314156   2.419733   2.068317   1.545501;\
        2.336352   2.411541   2.458476   1.606861   1.864960   2.373833;\
        2.239585   1.571087   1.593047   1.708980   1.646232   2.258704;\
    ")

    A_minus_2 = pa.mat("\
        -1.938802 -1.798010  -1.980322  -2.493936  -2.126745  -1.948592;\
        -1.562758 -1.941044  -2.149362  -2.045465  -1.703847  -1.964563;\
        -2.492474 -2.031309  -1.685844  -1.580267  -1.931683  -2.454499;\
        -1.663648 -1.588459  -1.541524  -2.393139  -2.135040  -1.626167;\
        -1.760415 -2.428913  -2.406953  -2.291020  -2.353768  -1.741296;\
    ")

    two_minus_A = pa.mat("\
        1.938802   1.798010   1.980322   2.493936   2.126745   1.948592;\
        1.562758   1.941044   2.149362   2.045465   1.703847   1.964563;\
        2.492474   2.031309   1.685844   1.580267   1.931683   2.454499;\
        1.663648   1.588459   1.541524   2.393139   2.135040   1.626167;\
        1.760415   2.428913   2.406953   2.291020   2.353768   1.741296;\
    ")

    A_times_2 = pa.mat("\
        0.122396   0.403980   0.039356  -0.987872  -0.253490   0.102816;\
        0.874484   0.117912  -0.298724  -0.090930   0.592306   0.070874;\
        -0.984948  -0.062618   0.628312   0.839466   0.136634  -0.908998;\
        0.672704   0.823082   0.916952  -0.786278  -0.270080   0.747666;\
        0.479170  -0.857826  -0.813906  -0.582040  -0.707536   0.517408;\
    ")

    A_elem_mul_A = pa.mat("\
        3.74519520400000e-03   4.07999601000000e-02   3.87223684000000e-04   2.43972772096000e-01   1.60642950250000e-02   2.64278246400000e-03;\
        1.91180566564000e-01   3.47580993600000e-03   2.23090070440000e-02   2.06706622500000e-03   8.77065994090000e-02   1.25578096900000e-03;\
        2.42530640676000e-01   9.80253481000000e-04   9.86939923360000e-02   1.76175791289000e-01   4.66721248900000e-03   2.06569341001000e-01;\
        1.13132667904000e-01   1.69365994681000e-01   2.10200242576000e-01   1.54558273321000e-01   1.82358016000000e-02   1.39751111889000e-01;\
        5.74009722250000e-02   1.83966361569000e-01   1.65610744209000e-01   8.46926404000000e-02   1.25151797824000e-01   6.69277596160000e-02;\
    ")

    A_div_2 = pa.mat("\
        0.03059900000000000   0.10099500000000000   0.00983900000000000  -0.24696799999999999  -0.06337250000000000   0.02570400000000000;\
        0.21862100000000001   0.02947800000000000  -0.07468100000000000  -0.02273250000000000   0.14807650000000000   0.01771850000000000;\
        -0.24623700000000001  -0.01565450000000000   0.15707800000000000   0.20986650000000001   0.03415850000000000  -0.22724949999999999;\
        0.16817599999999999   0.20577050000000000   0.22923800000000000  -0.19656950000000001  -0.06752000000000000   0.18691650000000001;\
        0.11979250000000000  -0.21445649999999999  -0.20347650000000000  -0.14551000000000000  -0.17688400000000001   0.12935199999999999;\
    ")

    two_div_A = pa.mat("\
        32.68080656230595     9.90148027130056   101.63634515702815    -4.04910757669010   -15.77971517614107    38.90445066915655;\
        4.57412599887476    33.92360404369360   -13.39028668603795   -43.98988232706478     6.75326604829261    56.43818607669949;\
        -4.06112810016366   -63.87939570091667     6.36626389437095     4.76493389845449    29.27529019131402    -4.40044972596199;\
        5.94615165065170     4.85978310787990     4.36227850530889    -5.08725921366234   -14.81042654028436     5.34998247880738;\
        8.34776801552685    -4.66295029528133    -4.91457244448376    -6.87237990516116    -5.65342258203116     7.73084297111757;\
    ")

    ones = pa.ones(pa.size(A))
    # Scalar-on-the-left "2 - A" and "2 / A" are built from an explicit matrix
    # of twos: ``2 / A`` raises "TypeError: unsupported operand type(s) for /:
    # 'int' and 'pyarma.mat'", and subtraction follows the same pattern so it
    # does not rely on a reflected operator either.
    twos = 2 * ones

    # Algebraic identities that must hold bit-for-bit.
    assert math.isclose(_abs_diff_sum(A, A), 0.0)
    assert math.isclose(_abs_diff_sum(2 * A, 2 * A), 0.0)
    assert math.isclose(_abs_diff_sum(2 * A, A + A), 0.0)

    assert math.isclose(_abs_diff_sum(A + 0, A), 0.0)
    assert math.isclose(_abs_diff_sum(A - 0, A), 0.0)

    assert math.isclose(_abs_diff_sum(A * 1, A), 0.0)
    assert math.isclose(_abs_diff_sum(A / 1, A), 0.0)

    # Scalar addition and subtraction, with the scalar on either side.
    assert math.isclose(_abs_diff_sum(A + 2, A_plus_2), 0.0, abs_tol=tol)
    assert math.isclose(_abs_diff_sum(2 + A, A_plus_2), 0.0, abs_tol=tol)

    assert math.isclose(_abs_diff_sum(A - 2, A_minus_2), 0.0, abs_tol=tol)
    assert math.isclose(_abs_diff_sum(twos - A, two_minus_A), 0.0, abs_tol=tol)

    # Doubling is exact in binary floating point, so no tolerance is needed.
    assert math.isclose(_abs_diff_sum(A * 2, A_times_2), 0.0)
    assert math.isclose(_abs_diff_sum(2 * A, A_times_2), 0.0)
    assert math.isclose(_abs_diff_sum(A + A, A_times_2), 0.0)

    assert math.isclose(_abs_diff_sum(A + A + A - A, A_times_2), 0.0, abs_tol=tol)
    assert math.isclose(_abs_diff_sum(3 * A - A, A_times_2), 0.0, abs_tol=tol)

    # Halving is likewise exact.
    assert math.isclose(_abs_diff_sum(A / 2, A_div_2), 0.0)
    assert math.isclose(_abs_diff_sum(A * 0.5, A_div_2), 0.0)

    two_over_A = twos / A
    assert math.isclose(_abs_diff_sum(two_over_A, two_div_A), 0.0, abs_tol=tol)

    # ``@`` is the element-wise (Schur) product here: A is 5x6, so a true
    # matrix product A @ A would not be defined, and it is compared to square(A).
    assert math.isclose(_abs_diff_sum(A @ A, A_elem_mul_A), 0.0, abs_tol=tol)
    assert math.isclose(_abs_diff_sum(A @ A, pa.square(A)), 0.0)

    # (2 / a) * (a / 2) == 1 up to rounding.
    assert math.isclose(_abs_diff_sum(two_over_A @ (A / 2), ones), 0.0, abs_tol=tol)

    # A full-range subview must behave like the matrix it views.
    assert math.isclose(_abs_diff_sum(A, A[:, :]), 0.0)

    assert math.isclose(_abs_diff_sum(A[:, :] + 2, A_plus_2), 0.0, abs_tol=tol)
    assert math.isclose(_abs_diff_sum(2 + A[:, :], A_plus_2), 0.0, abs_tol=tol)

    # Subtract before taking abs so deviations of opposite sign cannot cancel.
    assert math.isclose(_abs_diff_sum((A + A + A) / (3 * A), ones), 0.0)

    # Element-wise operations on operands of mismatched size must raise.
    with pytest.raises(RuntimeError):
        A + pa.randu(A.n_rows + 1, A.n_cols)
    with pytest.raises(RuntimeError):
        A + pa.randu(A.n_rows, A.n_cols + 1)