# -*- coding: utf-8 -*-
from math import radians
from copy import deepcopy, copy
from random import choice, seed
from itertools import permutations
import numpy as np
from .. import Crystal, Atom, Lattice, graphite
from ... import rotation_matrix, transform
import unittest

# seed(23)  # Uncomment to make the random choice of built-in crystal in setUp reproducible

# ase is an optional dependency: record in ASE whether it can be imported,
# so that tests which need it can be skipped rather than fail on import.
try:
    import ase
    ASE = True
except ImportError:
    ASE = False

@unittest.skipUnless(ASE, 'ASE not importable')
class TestAseAtoms(unittest.TestCase):
    """ Tests of the conversion between Crystal objects and ase Atoms objects.

    Each test runs against one randomly chosen built-in crystal. """

    def setUp(self):
        builtin_names = list(Crystal.builtins)
        self.crystal = Crystal.from_database(choice(builtin_names))

    def test_construction(self):
        """ Test that Crystal.ase_atoms() runs and preserves the length of the crystal """
        converted = self.crystal.ase_atoms()
        self.assertEqual(len(self.crystal), len(converted))

    def test_back_and_forth(self):
        """ Test the round trip Crystal -> ase Atoms -> Crystal """
        roundtrip = Crystal.from_ase(self.crystal.ase_atoms())

        # ase handles coordinates differently, which can lead to rounding
        # differences beyond 1e-3. The two crystals therefore cannot be compared
        # as sets, i.e. self.assertSetEqual(set(self.crystal), set(roundtrip));
        # only their lengths are compared.
        self.assertEqual(len(self.crystal), len(roundtrip))

class TestCrystalMethods(unittest.TestCase):
    """ Tests of general Crystal methods, run on one randomly chosen built-in crystal. """

    def setUp(self):
        builtin_names = list(Crystal.builtins)
        self.crystal = Crystal.from_database(choice(builtin_names))

    def test_array(self):
        """ Test that Crystal.__array__ yields an array of shape (len(crystal), 4) """
        as_array = np.array(self.crystal)
        expected_shape = (len(self.crystal), 4)
        self.assertSequenceEqual(as_array.shape, expected_shape)

class TestSpglibMethods(unittest.TestCase):
    """ Tests of Crystal.spacegroup_info and Crystal.primitive. """

    def test_spacegroup_info_graphite(self):
        """ Test that Crystal.spacegroup_info() returns the expected values for graphite """
        graphite_info = Crystal.from_database('C').spacegroup_info()

        expected = dict(
            international_number=194,
            hall_number=488,
            international_symbol='P6_3/mmc',
            international_full='P 6_3/m 2/m 2/c',
            hall_symbol='-P 6c 2c',
            pointgroup='D6h',
        )

        self.assertDictEqual(graphite_info, expected)

    def test_primitive(self):
        """ Test that every built-in crystal has a primitive cell that is
        no longer than the crystal itself """
        for builtin_name in Crystal.builtins:
            with self.subTest(builtin_name):
                full_cell = Crystal.from_database(builtin_name)
                primitive_cell = full_cell.primitive(symprec=0.1)
                self.assertLessEqual(len(primitive_cell), len(full_cell))

class TestCrystalRotations(unittest.TestCase):
    """ Tests of Crystal.transform and Crystal.__eq__ under scalings and rotations.

    Each test starts from the first crystal obtained by iterating Crystal.builtins. """

    def setUp(self):
        self.crystal = Crystal.from_database(next(iter(Crystal.builtins)))
    
    def test_crystal_equality(self):
        """ Tests that Crystal.__eq__ is working properly """
        self.assertEqual(self.crystal, self.crystal)

        cryst2 = deepcopy(self.crystal)
        cryst2.transform(2*np.eye(3)) # This stretches lattice vectors, symmetry operators
        # The copy must be a distinct object that no longer compares equal
        self.assertIsNot(self.crystal, cryst2)
        self.assertNotEqual(self.crystal, cryst2)

        # Undoing the scaling must restore equality
        cryst2.transform(0.5*np.eye(3))
        self.assertEqual(self.crystal, cryst2)
    
    def test_trivial_rotation(self):
        """ Test rotation by 360 deg around all axes. """
        for axis in ([1,0,0], [0,1,0], [0,0,1]):
            with self.subTest(axis = axis):
                # Rotate a fresh copy so that each axis is tested independently
                rotated = deepcopy(self.crystal)
                rotated.transform(rotation_matrix(radians(360), axis))
                self.assertEqual(rotated, self.crystal)
    
    def test_identity_transform(self):
        """ Tests the trivial identity transform """
        transf = deepcopy(self.crystal)
        transf.transform(np.eye(3))
        self.assertEqual(self.crystal, transf)
    
    def test_one_axis_rotation(self):
        """ Tests the crystal orientation after rotations. """
        unrotated = deepcopy(self.crystal)
        # Rotating by 37 deg must change the crystal...
        self.crystal.transform(rotation_matrix(radians(37), [0,1,0]))
        self.assertNotEqual(unrotated, self.crystal)
        # ... and rotating back by the same angle must restore it
        self.crystal.transform(rotation_matrix(radians(-37), [0,1,0]))
        self.assertEqual(unrotated, self.crystal)

    def test_wraparound_rotation(self):
        """ Test that rotations by angles differing by 360 deg give equal crystals """
        cryst1 = deepcopy(self.crystal)
        cryst2 = deepcopy(self.crystal)

        cryst1.transform(rotation_matrix(radians(22.3), [0,0,1]))
        cryst2.transform(rotation_matrix(radians(22.3 - 360), [0,0,1]))
        self.assertEqual(cryst1, cryst2)
    
class TestCrystalConstructors(unittest.TestCase):
    """ Tests of the Crystal constructors from_database, from_pdb and from_cod. """

    def test_builtins(self):
        """ Test that all names in Crystal.builtins build without errors,
        and that Crystal.source is correctly recorded. """
        for name in Crystal.builtins:
            with self.subTest(name):
                c = Crystal.from_database(name)

                self.assertIn(name, c.source)
    
    def test_builtins_wrong_name(self):
        """ Test that a name not in Crystal.builtins will raise a ValueError """
        # Only the raised exception matters; the return value is not needed
        with self.assertRaises(ValueError):
            Crystal.from_database('___')
    
    def test_from_pdb(self):
        """ Test Crystal.from_pdb constructor """
        # the tests on PDBParser are also using the test_cache folder
        c = Crystal.from_pdb('1fbb', download_dir = 'test_cache')
        self.assertIn('1fbb', c.source)
    
    def test_from_cod(self):
        """ Test building a Crystal object from the COD """
        # revision = None and latest revision should give the same Crystal
        c = Crystal.from_cod(1521124, download_dir = 'test_cache')
        c2 = Crystal.from_cod(1521124, revision = 176429, download_dir = 'test_cache')

        self.assertSetEqual(set(c), set(c2))

# Run all test cases of this module when it is executed as the main program.
# NOTE: the relative imports at the top of the file require the module to be
# run as part of its package (e.g. with ``python -m``).
if __name__ == '__main__':
    unittest.main()