import unittest
import tensorflow as tf
from kgcnn.layers.conv.acsf_conv import ACSFG2, ACSFG4
from kgcnn.graph.adj import get_angle_indices
import numpy as np


class ACSFTest(unittest.TestCase):
    """Regression tests for the ``ACSFG2`` and ``ACSFG4`` layers.

    The fixtures are three molecules stored as plain nested lists, one entry per
    molecule:

    * ``positions``: per-atom ``[x, y, z]`` coordinates,
    * ``atomic_number``: per-atom atomic numbers,
    * ``edge_index``: per-molecule lists of ``[i, j]`` atom-index pairs.

    Each test converts the fixtures to ragged tensors (``ragged_rank=1``). It runs
    the layer on the last molecule only and compares the feature vector of that
    molecule's first atom with stored reference values (presumably recorded from
    a known-good run — TODO confirm provenance).
    """

    positions = [
        [
            [-2.51593, 1.12614, 0.00000],
            [-1.80565, 0.90562, 0.59108],
            [-2.22951, 1.80475, -0.59996],
        ],
        [
            [-0.89480, 1.35710, -0.08980],
            [0.41190, 1.95030, -0.56090],
            [-1.33170, 2.01250, 0.69270],
            [0.36860, 3.07540, -1.04050],
            [2.82390, 1.94820, -0.95600],
            [-1.60060, 1.30400, -0.94510],
            [-0.80410, 0.33990, 0.33540],
            [1.61520, 1.29220, -0.47380],
            [1.70940, -0.04710, 0.09010],
            [2.99110, 2.89170, -0.39360],
            [3.72680, 1.31430, -0.83000],
            [2.72120, 2.18050, -2.03760],
            [2.75090, -0.43150, 0.08540],
            [1.36380, -0.03950, 1.14580],
            [1.09270, -0.75370, -0.50520],
        ],
        [
            [-0.8067180, 0.0475439, 1.5251484],
            [0.3116624, 0.2691858, 0.0786324],
            [-0.2255019, -0.2220590, 2.4060774],
            [-1.4979938, -0.7535269, 1.2757967],
            [-1.3445395, 0.9720244, 1.7264612],
            [-1.7790900, 0.0209031, -1.8054100],
            [-0.5677430, -1.0588900, -0.9579280],
            [-2.5303600, 0.4056500, -1.1116500],
            [-2.2931400, -0.5487300, -2.5840100],
            [-1.2769100, 0.8703580, -2.2718800],
            [2.5277800, 0.5632240, 1.8126600],
            [1.2848000, 1.6235700, 0.9897200],
            [3.2810100, 1.1954700, 2.2894300],
            [2.0753200, -0.0677200, 2.5815300],
            [3.0240100, -0.0818562, 1.0861300],
        ],
    ]

    atomic_number = [
        [8, 1, 16],
        [6, 6, 1, 8, 6, 1, 1, 7, 6, 1, 1, 1, 1, 1, 1],
        [6, 16, 1, 1, 1, 6, 16, 1, 1, 1, 6, 16, 1, 1, 1],
    ]

    # Every molecule is fully connected without self-loops. For a molecule with
    # n atoms this yields all ordered pairs [i, j] with i != j, sorted by i and
    # then j: 3 * 2 = 6 edges for the first molecule, 15 * 14 = 210 for the others.
    # This is identical to the previously hand-written literal lists, but cannot
    # drift out of sync with ``positions`` or contain typos.
    edge_index = [
        [[i, j] for i in range(len(mol)) for j in range(len(mol)) if i != j]
        for mol in positions
    ]

    def _ragged_atomic_number_and_positions(self):
        """Return ``(atomic_number, positions)`` as ragged tensors with ragged_rank=1."""
        atomic_number = tf.ragged.constant(self.atomic_number, ragged_rank=1, dtype="int64")
        positions = tf.ragged.constant(self.positions, ragged_rank=1, inner_shape=(3,))
        return atomic_number, positions

    def _assert_first_atom_close(self, out, expected):
        """Compare the features of atom 0 in molecule 0 of ``out`` with ``expected``.

        Uses an absolute tolerance of 1e-4 and no relative tolerance, matching the
        former ``np.abs(a - b) < 1e-04`` check. Unlike that check, ``assert_allclose``
        also rejects shape mismatches and reports the offending entries on failure.
        """
        np.testing.assert_allclose(np.asarray(out[0][0]), expected, rtol=0.0, atol=1e-04)

    def test_acsf_g2(self):
        """Check ``ACSFG2`` features, computed from ``edge_index`` atom pairs."""
        g2_kwargs = {"eta": [0.0, 0.3], "rs": [0.0, 3.0], "rc": 10.0, "elements": [1, 6, 16]}
        layer = ACSFG2(**ACSFG2.make_param_table(**g2_kwargs))
        atomic_number, positions = self._ragged_atomic_number_and_positions()
        edge_index = tf.ragged.constant(self.edge_index, ragged_rank=1, inner_shape=(2,))

        # Slice the batch down to the last molecule; out[0] is then that molecule.
        out = layer([atomic_number[2:], positions[2:], edge_index[2:]])

        # Expected result for last molecule, first atom.
        expected_result = np.array([7.011673, 2.1447349, 7.011673, 4.2706203, 1.4739769, 0.04355875,
                                    1.4739769, 1.3946176, 2.579667, 0.5183595, 2.579667, 2.230977])
        self._assert_first_atom_close(out, expected_result)

    def test_acsf_g4(self):
        """Check ``ACSFG4`` features, computed from index triplets derived from ``edge_index``."""
        atomic_number, positions = self._ragged_atomic_number_and_positions()
        # Only element [1] of get_angle_indices' return value is used; it is
        # stored with inner_shape=(3,) below, i.e. as index triplets.
        angle_index = [get_angle_indices(np.array(x), edge_pairing="ik")[1].tolist() for x in self.edge_index]
        angle_index = tf.ragged.constant(angle_index, ragged_rank=1, inner_shape=(3,))
        g4_kwargs = {"eta": [0.0, 0.3], "lamda": [-1.0, 1.0], "rc": 6.0, "zeta": [1.0, 8.0],
                     "elements": [1, 6, 16],
                     "multiplicity": 2.0}
        layer = ACSFG4(**ACSFG4.make_param_table(**g4_kwargs))

        # Slice the batch down to the last molecule; out[0] is then that molecule.
        out = layer([atomic_number[2:], positions[2:], angle_index[2:]])

        # Expected result for last molecule, first atom.
        expected_result = np.array(
            [4.093878746032715, 3.8475711345672607, 0.45441314578056335, 0.9100052118301392, 0.51732337474823,
             0.2600725293159485, 0.031431298702955246, 0.002341042272746563, 0.6844168305397034, 2.004915952682495,
             0.13527904450893402, 1.1940642595291138, 0.0003382707363925874, 0.001611050684005022,
             1.046786655933829e-05, 0.0008169701904989779, 4.2289838790893555, 4.576600551605225, 0.44851353764533997,
             0.6955477595329285, 0.0896565243601799, 0.054230786859989166, 0.00514655327424407, 0.001017893897369504,
             0.001709476695396006, 0.0008070105686783791, 0.00011411488230805844, 2.8149503350505256e-07,
             1.0225409408093356e-10, 4.827216057434747e-11, 6.82589791980992e-12, 1.683791046045227e-14,
             0.27127137780189514, 1.2421965599060059, 0.0007643443532288074, 0.4701008200645447, 0.00020793949079234153,
             0.0013072892324998975, 9.969490122330171e-09, 0.000549450283870101, 0.30053770542144775,
             1.2993861436843872, 0.0007059765048325062, 0.4033553898334503, 0.0027985533233731985, 0.016471944749355316,
             1.2086698575330956e-07, 0.005516418721526861])
        self._assert_first_atom_close(out, expected_result)


if __name__ == "__main__":
    # Run all test cases in this module when it is executed as a script.
    unittest.main()
