import numpy as nm
import pytest

import sfepy.base.testing as tst

# Reference data for test_base_functions_values().
#
# Keys have the form '<geometry>_<X><order>[B][_grad]':
#   - <geometry> (the first three characters, e.g. '2_3', '3_8') selects the
#     GeometryElement from the `gels` fixture;
#   - <order> (the character at index 5) is the polynomial order;
#   - an optional 'B' right after the order requests force_bubble=True;
#   - the '_grad' suffix means the array holds base function gradients
#     (eval_base(..., diff=True)) instead of values.
# The family letter <X> (P/Q) is informational only - the test does not
# read it.
#
# The first axis of each array runs over the reference geometry coordinates
# (ps.geometry.coors) followed by one extra point whose coordinates are all
# equal to 0.2.
test_bases = {
    '2_3_P1'
    : nm.array([[[ 1. ,  0. ,  0. ]],

                [[ 0. ,  1. ,  0. ]],

                [[ 0. ,  0. ,  1. ]],

                [[ 0.6,  0.2,  0.2]]]),
    '2_3_P1_grad'
    : nm.array([[[-1.,  1.,  0.],
                 [-1.,  0.,  1.]],

                [[-1.,  1.,  0.],
                 [-1.,  0.,  1.]],

                [[-1.,  1.,  0.],
                 [-1.,  0.,  1.]],

                [[-1.,  1.,  0.],
                 [-1.,  0.,  1.]]]),
    '2_4_Q1'
    : nm.array([[[ 1.  ,  0.  ,  0.  ,  0.  ]],

                [[ 0.  ,  1.  ,  0.  ,  0.  ]],

                [[ 0.  ,  0.  ,  1.  ,  0.  ]],

                [[ 0.  ,  0.  ,  0.  ,  1.  ]],

                [[ 0.64,  0.16,  0.04,  0.16]]]),
    '2_4_Q1_grad'
    : nm.array([[[-1. ,  1. ,  0. , -0. ],
                 [-1. , -0. ,  0. ,  1. ]],

                [[-1. ,  1. ,  0. , -0. ],
                 [-0. , -1. ,  1. ,  0. ]],

                [[-0. ,  0. ,  1. , -1. ],
                 [-0. , -1. ,  1. ,  0. ]],

                [[-0. ,  0. ,  1. , -1. ],
                 [-1. , -0. ,  0. ,  1. ]],

                [[-0.8,  0.8,  0.2, -0.2],
                 [-0.8, -0.2,  0.2,  0.8]]]),
    '3_4_P0' : nm.ones((5, 1, 1)),
    '3_4_P0_grad' : nm.zeros((5, 3, 1)),
    '3_8_Q0' : nm.ones((9, 1, 1)),
    '3_8_Q0_grad' : nm.zeros((9, 3, 1)),
    '3_8_Q1'
    : nm.array([[[ 1.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  1.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  1.   ,  0.   ,  0.   ,  0.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  0.   ,  1.   ,  0.   ,  0.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  0.   ,  0.   ,  1.   ,  0.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  1.   ,  0.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  1.   ,
                   0.   ]],

                [[ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,
                   1.   ]],

                [[ 0.512,  0.128,  0.032,  0.128,  0.128,  0.032,  0.008,
                   0.032]]]),
    '3_8_Q1_grad'
    : nm.array([[[-1.  ,  1.  ,  0.  , -0.  , -0.  ,  0.  ,  0.  , -0.  ],
                 [-1.  , -0.  ,  0.  ,  1.  , -0.  , -0.  ,  0.  ,  0.  ],
                 [-1.  , -0.  , -0.  , -0.  ,  1.  ,  0.  ,  0.  ,  0.  ]],

                [[-1.  ,  1.  ,  0.  , -0.  , -0.  ,  0.  ,  0.  , -0.  ],
                 [-0.  , -1.  ,  1.  ,  0.  , -0.  , -0.  ,  0.  ,  0.  ],
                 [-0.  , -1.  , -0.  , -0.  ,  0.  ,  1.  ,  0.  ,  0.  ]],

                [[-0.  ,  0.  ,  1.  , -1.  , -0.  ,  0.  ,  0.  , -0.  ],
                 [-0.  , -1.  ,  1.  ,  0.  , -0.  , -0.  ,  0.  ,  0.  ],
                 [-0.  , -0.  , -1.  , -0.  ,  0.  ,  0.  ,  1.  ,  0.  ]],

                [[-0.  ,  0.  ,  1.  , -1.  , -0.  ,  0.  ,  0.  , -0.  ],
                 [-1.  , -0.  ,  0.  ,  1.  , -0.  , -0.  ,  0.  ,  0.  ],
                 [-0.  , -0.  , -0.  , -1.  ,  0.  ,  0.  ,  0.  ,  1.  ]],

                [[-0.  ,  0.  ,  0.  , -0.  , -1.  ,  1.  ,  0.  , -0.  ],
                 [-0.  , -0.  ,  0.  ,  0.  , -1.  , -0.  ,  0.  ,  1.  ],
                 [-1.  , -0.  , -0.  , -0.  ,  1.  ,  0.  ,  0.  ,  0.  ]],

                [[-0.  ,  0.  ,  0.  , -0.  , -1.  ,  1.  ,  0.  , -0.  ],
                 [-0.  , -0.  ,  0.  ,  0.  , -0.  , -1.  ,  1.  ,  0.  ],
                 [-0.  , -1.  , -0.  , -0.  ,  0.  ,  1.  ,  0.  ,  0.  ]],

                [[-0.  ,  0.  ,  0.  , -0.  , -0.  ,  0.  ,  1.  , -1.  ],
                 [-0.  , -0.  ,  0.  ,  0.  , -0.  , -1.  ,  1.  ,  0.  ],
                 [-0.  , -0.  , -1.  , -0.  ,  0.  ,  0.  ,  1.  ,  0.  ]],

                [[-0.  ,  0.  ,  0.  , -0.  , -0.  ,  0.  ,  1.  , -1.  ],
                 [-0.  , -0.  ,  0.  ,  0.  , -1.  , -0.  ,  0.  ,  1.  ],
                 [-0.  , -0.  , -0.  , -1.  ,  0.  ,  0.  ,  0.  ,  1.  ]],

                [[-0.64,  0.64,  0.16, -0.16, -0.16,  0.16,  0.04, -0.04],
                 [-0.64, -0.16,  0.16,  0.64, -0.16, -0.04,  0.04,  0.16],
                 [-0.64, -0.16, -0.04, -0.16,  0.64,  0.16,  0.04,  0.16]]]),
    '3_4_P2'
    : nm.array([[[ 1.  , -0.  , -0.  , -0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
                   0.  ]],

                [[-0.  ,  1.  , -0.  , -0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
                   0.  ]],

                [[-0.  , -0.  ,  1.  , -0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
                   0.  ]],

                [[-0.  , -0.  , -0.  ,  1.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
                   0.  ]],

                [[-0.08, -0.12, -0.12, -0.12,  0.32,  0.16,  0.32,  0.32,  0.16,
                  0.16]]]),
    '3_4_P2_grad'
    : nm.array([[[-3. , -1. ,  0. ,  0. ,  4. ,  0. ,  0. ,  0. ,  0. ,  0. ],
                 [-3. ,  0. , -1. ,  0. ,  0. ,  0. ,  4. ,  0. ,  0. ,  0. ],
                 [-3. ,  0. ,  0. , -1. ,  0. ,  0. ,  0. ,  4. ,  0. ,  0. ]],

                [[ 1. ,  3. ,  0. ,  0. , -4. ,  0. ,  0. ,  0. ,  0. ,  0. ],
                 [ 1. ,  0. , -1. ,  0. , -4. ,  4. ,  0. ,  0. ,  0. ,  0. ],
                 [ 1. ,  0. ,  0. , -1. , -4. ,  0. ,  0. ,  0. ,  4. ,  0. ]],

                [[ 1. , -1. ,  0. ,  0. ,  0. ,  4. , -4. ,  0. ,  0. ,  0. ],
                 [ 1. ,  0. ,  3. ,  0. ,  0. ,  0. , -4. ,  0. ,  0. ,  0. ],
                 [ 1. ,  0. ,  0. , -1. ,  0. ,  0. , -4. ,  0. ,  0. ,  4. ]],

                [[ 1. , -1. ,  0. ,  0. ,  0. ,  0. ,  0. , -4. ,  4. ,  0. ],
                [ 1. ,  0. , -1. ,  0. ,  0. ,  0. ,  0. , -4. ,  0. ,  4. ],
                 [ 1. ,  0. ,  0. ,  3. ,  0. ,  0. ,  0. , -4. ,  0. ,  0. ]],

                [[-0.6, -0.2,  0. ,  0. ,  0.8,  0.8, -0.8, -0.8,  0.8,  0. ],
                 [-0.6,  0. , -0.2,  0. , -0.8,  0.8,  0.8, -0.8,  0. ,  0.8],
                 [-0.6,  0. ,  0. , -0.2, -0.8,  0. , -0.8,  0.8,  0.8,  0.8]]]),

    '3_4_P2B'
    : nm.array([[[ 1.     , -0.     , -0.     , -0.     ,  0.     ,  0.     ,
                   0.     ,  0.     ,  0.     ,  0.     ,  0.     ]],

                [[-0.     ,  1.     , -0.     , -0.     ,  0.     ,  0.     ,
                  0.     ,  0.     ,  0.     ,  0.     ,  0.     ]],

                [[-0.     , -0.     ,  1.     , -0.     ,  0.     ,  0.     ,
                  0.     ,  0.     ,  0.     ,  0.     ,  0.     ]],

                [[-0.     , -0.     , -0.     ,  1.     ,  0.     ,  0.     ,
                  0.     ,  0.     ,  0.     ,  0.     ,  0.     ]],

                [[-0.16192, -0.20192, -0.20192, -0.20192,  0.23808,  0.07808,
                  0.23808,  0.23808,  0.07808,  0.07808,  0.8192 ]]]),
    '3_4_P2B_grad'
    : nm.array([[[-3.    , -1.    ,  0.    ,  0.    ,  4.    ,  0.    ,  0.    ,
                  0.    ,  0.    ,  0.    ,  0.    ],
                 [-3.    ,  0.    , -1.    ,  0.    ,  0.    ,  0.    ,  4.    ,
                  0.    ,  0.    ,  0.    ,  0.    ],
                 [-3.    ,  0.    ,  0.    , -1.    ,  0.    ,  0.    ,  0.    ,
                  4.    ,  0.    ,  0.    ,  0.    ]],

                [[ 1.    ,  3.    ,  0.    ,  0.    , -4.    ,  0.    ,  0.    ,
                   0.    ,  0.    ,  0.    ,  0.    ],
                 [ 1.    ,  0.    , -1.    ,  0.    , -4.    ,  4.    ,  0.    ,
                   0.    ,  0.    ,  0.    ,  0.    ],
                 [ 1.    ,  0.    ,  0.    , -1.    , -4.    ,  0.    ,  0.    ,
                   0.    ,  4.    ,  0.    ,  0.    ]],

                [[ 1.    , -1.    ,  0.    ,  0.    ,  0.    ,  4.    , -4.    ,
                   0.    ,  0.    ,  0.    ,  0.    ],
                 [ 1.    ,  0.    ,  3.    ,  0.    ,  0.    ,  0.    , -4.    ,
                   0.    ,  0.    ,  0.    ,  0.    ],
                 [ 1.    ,  0.    ,  0.    , -1.    ,  0.    ,  0.    , -4.    ,
                   0.    ,  0.    ,  4.    ,  0.    ]],

                [[ 1.    , -1.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,
                   -4.    ,  4.    ,  0.    ,  0.    ],
                 [ 1.    ,  0.    , -1.    ,  0.    ,  0.    ,  0.    ,  0.    ,
                   -4.    ,  0.    ,  4.    ,  0.    ],
                 [ 1.    ,  0.    ,  0.    ,  3.    ,  0.    ,  0.    ,  0.    ,
                   -4.    ,  0.    ,  0.    ,  0.    ]],

                [[-0.8048, -0.4048, -0.2048, -0.2048,  0.5952,  0.5952, -1.0048,
                  -1.0048,  0.5952, -0.2048,  2.048 ],
                 [-0.8048, -0.2048, -0.4048, -0.2048, -1.0048,  0.5952,  0.5952,
                  -1.0048, -0.2048,  0.5952,  2.048 ],
                 [-0.8048, -0.2048, -0.2048, -0.4048, -1.0048, -0.2048, -1.0048,
                  0.5952,  0.5952,  0.5952,  2.048 ]]]),
}

@pytest.fixture(scope='module')
def gels():
    """
    Module-scoped fixture providing reference geometry elements.

    Returns a dict that maps each geometry key ('2_3', '2_4', '3_4', '3_8')
    to the GeometryElement constructed from that key, in that order.
    """
    from sfepy.discrete.fem.geometry_element import GeometryElement

    geometry_keys = ('2_3', '2_4', '3_4', '3_8')
    return {name: GeometryElement(name) for name in geometry_keys}

def test_base_functions_values(gels):
    """
    Check Lagrange base function values and gradients against the reference
    arrays stored in `test_bases`.

    Each polynomial space is evaluated in its reference geometry coordinates
    plus one extra point with all coordinates equal to 0.2. For values (not
    gradients), the test also checks that the base functions summed over the
    element nodes give one in every evaluation point.
    """
    from sfepy.base.base import ordered_iteritems
    from sfepy.discrete import PolySpace

    outcomes = []
    for key, expected in ordered_iteritems(test_bases):
        # Decode the key, e.g. '3_4_P2B_grad' -> geometry '3_4', order 2,
        # bubble-enriched, gradients requested.
        geometry = gels[key[:3]]
        want_grad = key.endswith('grad')
        poly_order = int(key[5])
        bubble = (key[6:7] == 'B')

        ps = PolySpace.any_from_args('aux', geometry, poly_order,
                                     base='lagrange',
                                     force_bubble=bubble)
        extra_point = [[0.2] * ps.geometry.dim]
        points = nm.r_[ps.geometry.coors, extra_point]

        computed = ps.eval_base(points, diff=want_grad)
        passed = nm.allclose(expected, computed, rtol=0.0, atol=1e-14)
        if passed and not want_grad:
            # Partition of unity: the values summed over the element nodes
            # must equal one in every evaluation point.
            passed = nm.allclose(computed.sum(axis=2), 1.0,
                                 rtol=0.0, atol=1e-14)

        tst.report('%s: %s' % (key, passed))
        outcomes.append(passed)

    assert all(outcomes)

def test_base_functions_delta(gels):
    r"""
    Test :math:`\delta` property of base functions evaluated in the
    reference element nodes.

    For every geometry provided by the `gels` fixture and every Lagrange
    order 0-10 (without bubble enrichment), the base functions evaluated in
    the polynomial space nodes must form the identity matrix. The absolute
    tolerance grows linearly with the order, as ``(order + 1) * 1e-14``.
    """
    from sfepy.base.base import ordered_iteritems
    from sfepy.discrete import PolySpace

    ok = True
    for key, gel in ordered_iteritems(gels):
        for order in range(11):
            ps = PolySpace.any_from_args('aux', gel, order,
                                         base='lagrange',
                                         force_bubble=False)
            bf = ps.eval_base(ps.node_coors)
            _ok = nm.allclose(nm.eye(ps.n_nod),
                              bf.squeeze(),
                              rtol=0.0, atol=(order + 1) * 1e-14)

            tst.report('%s order %d (n_nod: %d): %s'
                       % (key, order, ps.n_nod, _ok))

            # Accumulate inside the order loop: previously only the last
            # order's result per geometry reached the final assertion, so
            # failures at lower orders were silently ignored.
            ok = ok and _ok

    assert ok
