"""
peakrdl-python is a tool to generate Python Register Access Layer (RAL) from SystemRDL
Copyright (C) 2021 - 2025

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of
the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

This package is intended to be distributed as part of automatically generated code by the PeakRDL
Python tool. It provides the base test class common to both the async and non-async versions
"""
import unittest
from abc import ABC, abstractmethod
from typing import Union, Optional

from ..lib import FieldReadWrite, FieldReadOnly, FieldWriteOnly
from ..lib import FieldEnumReadWrite, FieldEnumReadOnly, FieldEnumWriteOnly
from ..lib import FieldAsyncReadOnly, FieldAsyncWriteOnly, FieldAsyncReadWrite
from ..lib import FieldEnumAsyncReadOnly, FieldEnumAsyncWriteOnly, FieldEnumAsyncReadWrite
from ..lib import RegReadOnly, RegReadWrite, RegWriteOnly
from ..lib import RegAsyncReadOnly, RegAsyncReadWrite, RegAsyncWriteOnly
from ..lib import AddressMap, AsyncAddressMap
from ..lib import RegFile, AsyncRegFile
from ..lib import MemoryReadOnly, MemoryReadOnlyLegacy
from ..lib import MemoryWriteOnly, MemoryWriteOnlyLegacy
from ..lib import MemoryReadWrite, MemoryReadWriteLegacy
from ..lib import MemoryAsyncReadOnly, MemoryAsyncReadOnlyLegacy
from ..lib import MemoryAsyncWriteOnly, MemoryAsyncWriteOnlyLegacy
from ..lib import MemoryAsyncReadWrite, MemoryAsyncReadWriteLegacy
from ..lib.base_register import BaseReg
from ..lib import Base
from .utilities import get_field_bitmask_int, get_field_inv_bitmask
from ..sim_lib.simulator import BaseSimulator

class CommonTestBase(unittest.TestCase, ABC):
    """
    Base test class for the autogenerated register tests, shared by the async and
    non-async cases.

    Concrete subclasses must provide the ``simulator_instance`` property.
    """

    @property
    @abstractmethod
    def simulator_instance(self) -> BaseSimulator:
        """
        Simulator configured for the DUT
        """

    # pylint:disable-next=too-many-arguments
    def _single_field_property_test(self, *,
                                    fut: Union[FieldReadWrite,
                                               FieldReadOnly,
                                               FieldWriteOnly,
                                               FieldEnumReadWrite,
                                               FieldEnumReadOnly,
                                               FieldEnumWriteOnly,
                                               FieldAsyncReadOnly,
                                               FieldAsyncWriteOnly,
                                               FieldAsyncReadWrite,
                                               FieldEnumAsyncReadOnly,
                                               FieldEnumAsyncWriteOnly,
                                               FieldEnumAsyncReadWrite],
                                    lsb: int,
                                    msb: int,
                                    low: int,
                                    high: int,
                                    is_volatile: bool,
                                    default: Optional[int]) -> None:
        """
        Check the static properties of a single field against their expected values.

        Args:
            fut: field under test
            lsb: expected ``lsb`` of the field
            msb: expected ``msb`` of the field
            low: expected ``low`` bit position of the field
            high: expected ``high`` bit position of the field
            is_volatile: expected ``is_volatile`` flag of the field
            default: expected default value as an integer, or ``None`` if the field is
                expected to have no default. For enumerated fields the integer is mapped
                onto the enumeration; if it does not correspond to any member, the
                field's ``default`` is expected to be ``None``.
        """
        self.assertEqual(fut.lsb, lsb)
        self.assertEqual(fut.msb, msb)
        self.assertEqual(fut.low, low)
        self.assertEqual(fut.high, high)

        # the masks are cross-checked against an independent calculation in the test
        # utilities rather than against values supplied by the caller
        self.assertEqual(fut.bitmask, get_field_bitmask_int(fut))
        self.assertEqual(fut.inverse_bitmask, get_field_inv_bitmask(fut))

        width = (fut.high - fut.low) + 1
        self.assertEqual(fut.width, width)
        self.assertEqual(fut.max_value, (2**width) - 1)
        self.assertEqual(fut.is_volatile, is_volatile)

        if default is None:
            self.assertIsNone(fut.default)
            return

        if isinstance(fut, (FieldEnumReadWrite,
                            FieldEnumReadOnly,
                            FieldEnumWriteOnly,
                            FieldEnumAsyncReadOnly,
                            FieldEnumAsyncWriteOnly,
                            FieldEnumAsyncReadWrite)):
            # pylint does not realise this is a class being returned rather than an object, so
            # is unhappy with the name
            # pylint:disable-next=invalid-name
            EnumCls = fut.enum_cls
            legal_values = [item.value for item in EnumCls]
            if default in legal_values:
                self.assertEqual(fut.default, EnumCls(default))
            else:
                # this is a special case if the default value for the field does not map
                # to a legal value of the encoding
                self.assertIsNone(fut.default)
        else:
            self.assertEqual(fut.default, default)

    def _single_register_property_test(self, *,
                                       rut: BaseReg,
                                       address: int,
                                       width: int,
                                       accesswidth: Optional[int]) -> None:
        """
        Check the static properties of a single register against their expected values.

        Args:
            rut: register under test
            address: expected ``address`` of the register
            width: expected ``width`` of the register
            accesswidth: expected ``accesswidth`` of the register; when ``None`` the
                access width is expected to equal ``width``
        """
        self.assertEqual(rut.address, address)
        self.assertEqual(rut.width, width)
        expected_accesswidth = width if accesswidth is None else accesswidth
        self.assertEqual(rut.accesswidth, expected_accesswidth)

    def _single_node_rdl_name_and_desc_test(self,
                                            dut: Base,
                                            rdl_name: Optional[str],
                                            rdl_desc: Optional[str]) -> None:
        """
        Check the SystemRDL Name and Desc properties for a node

        Args:
            dut: node under test
            rdl_name: expected ``rdl_name``, ``None`` if the node is expected to have none
            rdl_desc: expected ``rdl_desc``, ``None`` if the node is expected to have none
        """
        if rdl_name is None:
            self.assertIsNone(dut.rdl_name)
        else:
            self.assertEqual(dut.rdl_name, rdl_name)

        if rdl_desc is None:
            self.assertIsNone(dut.rdl_desc)
        else:
            self.assertEqual(dut.rdl_desc, rdl_desc)

    def _test_node_inst_name(self,
                             dut: Base,
                             parent_full_inst_name:str,
                             inst_name:str) -> None:
        """
        Test the `inst_name` and `full_inst_name` attributes of a node

        Args:
            dut: node under test
            parent_full_inst_name: full instance name of the node's parent
            inst_name: expected instance name of the node; the full instance name is
                expected to be ``<parent_full_inst_name>.<inst_name>``
        """
        self.assertEqual(dut.inst_name, inst_name)
        full_inst_name = parent_full_inst_name + '.' + inst_name
        self.assertEqual(dut.full_inst_name, full_inst_name)

    def _test_field_iterators(self, *,
                              rut: Union[RegReadOnly,
                                            RegReadWrite,
                                            RegWriteOnly,
                                            RegAsyncReadOnly,
                                            RegAsyncReadWrite,
                                            RegAsyncWriteOnly],
                              has_sw_readable: bool,
                              has_sw_writable: bool,
                              readable_fields: set[str],
                              writeable_fields: set[str]) -> None:
        """
        Check the field iterators of a register return the expected field names.

        Args:
            rut: register under test
            has_sw_readable: True if the register is expected to be software readable,
                in which case it must expose ``readable_fields``
            has_sw_writable: True if the register is expected to be software writable,
                in which case it must expose ``writable_fields``
            readable_fields: expected instance names of the readable fields (must be
                empty when ``has_sw_readable`` is False)
            writeable_fields: expected instance names of the writable fields (must be
                empty when ``has_sw_writable`` is False)

        Raises:
            TypeError: if the register is not of a readable/writable class when it is
                expected to be
        """
        if has_sw_readable:
            if not isinstance(rut, (RegReadOnly,
                                    RegReadWrite,
                                    RegAsyncReadOnly,
                                    RegAsyncReadWrite,
                                    )):
                raise TypeError(f'Register was expected to be readable, got {type(rut)}')

            child_readable_field_names = {field.inst_name for field in rut.readable_fields}

            self.assertEqual(readable_fields, child_readable_field_names)
        else:
            self.assertFalse(hasattr(rut, 'readable_fields'))
            # a register with no readable fields must not be given any expected names
            self.assertFalse(readable_fields)

        if has_sw_writable:
            if not isinstance(rut, (RegWriteOnly,
                                    RegReadWrite,
                                    RegAsyncWriteOnly,
                                    RegAsyncReadWrite,
                                    )):
                raise TypeError(f'Register was expected to be writable, got {type(rut)}')

            child_writeable_fields_names = {field.inst_name for field in rut.writable_fields}

            self.assertEqual(writeable_fields, child_writeable_fields_names)
        else:
            # the iterator attribute is spelt `writable_fields` (see above); checking any
            # other spelling would make this assertion vacuous
            self.assertFalse(hasattr(rut, 'writable_fields'))
            # a register with no writable fields must not be given any expected names
            self.assertFalse(writeable_fields)

        child_field_names = {field.inst_name for field in rut.fields}
        self.assertEqual(readable_fields | writeable_fields, child_field_names)

    def _test_register_iterators(self,
                                 dut: Union[AddressMap, AsyncAddressMap, RegFile, AsyncRegFile,
                                            MemoryReadOnly, MemoryReadOnlyLegacy,
                                            MemoryWriteOnly, MemoryWriteOnlyLegacy,
                                            MemoryReadWrite, MemoryReadWriteLegacy,
                                            MemoryAsyncReadOnly, MemoryAsyncReadOnlyLegacy,
                                            MemoryAsyncWriteOnly, MemoryAsyncWriteOnlyLegacy,
                                            MemoryAsyncReadWrite, MemoryAsyncReadWriteLegacy],
                                 readable_registers: set[str],
                                 writeable_registers: set[str]) -> None:
        """
        Check the register iterators of a container node return the expected names.

        Nodes whose class supports reading are checked via ``get_readable_registers``;
        others must not expose that method. Likewise for writing and
        ``get_writable_registers``. The union of both expected sets must match
        ``get_registers``.

        Args:
            dut: address map, register file or memory under test
            readable_registers: expected instance names of the readable registers
            writeable_registers: expected instance names of the writable registers
        """
        if isinstance(dut, (AddressMap, AsyncAddressMap, RegFile, AsyncRegFile,
                            MemoryReadOnly, MemoryReadOnlyLegacy,
                            MemoryReadWrite, MemoryReadWriteLegacy,
                            MemoryAsyncReadOnly, MemoryAsyncReadOnlyLegacy,
                            MemoryAsyncReadWrite, MemoryAsyncReadWriteLegacy)):
            child_readable_reg_names = {reg.inst_name for reg in
                                        dut.get_readable_registers(unroll=True)}
            self.assertEqual(readable_registers, child_readable_reg_names)
        else:
            self.assertFalse(hasattr(dut, 'get_readable_registers'))

        if isinstance(dut, (AddressMap, AsyncAddressMap, RegFile, AsyncRegFile,
                            MemoryWriteOnly, MemoryWriteOnlyLegacy,
                            MemoryReadWrite, MemoryReadWriteLegacy,
                            MemoryAsyncWriteOnly, MemoryAsyncWriteOnlyLegacy,
                            MemoryAsyncReadWrite, MemoryAsyncReadWriteLegacy)):
            child_writable_reg_names = {reg.inst_name for reg in
                                        dut.get_writable_registers(unroll=True)}
            self.assertEqual(writeable_registers, child_writable_reg_names)
        else:
            self.assertFalse(hasattr(dut, 'get_writable_registers'))

        child_reg_names = {reg.inst_name for reg in dut.get_registers(unroll=True)}
        self.assertEqual(readable_registers | writeable_registers, child_reg_names)


    def _test_memory_iterators(self,
                               dut: Union[AddressMap, AsyncAddressMap],
                               memories: set[str]) -> None:
        """
        Check ``get_memories`` of an address map returns the expected memory names.

        Args:
            dut: address map under test
            memories: expected instance names of the memories
        """
        child_mem_names = {mem.inst_name for mem in dut.get_memories(unroll=True)}
        self.assertEqual(memories, child_mem_names)

    def __test_section_iterators(self,
                                 dut: Union[AddressMap, AsyncAddressMap, RegFile, AsyncRegFile],
                                 sections: set[str]) -> None:
        """
        Check ``get_sections`` of an address map or register file returns the expected
        section names.

        Args:
            dut: address map or register file under test
            sections: expected instance names of the sections
        """
        child_section_names = {section.inst_name for section in dut.get_sections(unroll=True)}
        self.assertEqual(sections, child_section_names)

    def _test_addrmap_iterators(self, *,
                                dut: Union[AddressMap, AsyncAddressMap],
                                memories: set[str],
                                sections: set[str],
                                readable_registers: set[str],
                                writeable_registers: set[str]) -> None:
        """
        Check all the child iterators (registers, memories and sections) of an address map.

        Args:
            dut: address map under test
            memories: expected instance names of the memories
            sections: expected instance names of the sections
            readable_registers: expected instance names of the readable registers
            writeable_registers: expected instance names of the writable registers
        """
        self._test_register_iterators(dut=dut,
                                      readable_registers=readable_registers,
                                      writeable_registers=writeable_registers)
        self._test_memory_iterators(dut=dut,
                                    memories=memories)
        self.__test_section_iterators(dut=dut,
                                      sections=sections)

    def _test_regfile_iterators(self,
                                dut: Union[RegFile, AsyncRegFile],
                                sections: set[str],
                                readable_registers: set[str],
                                writeable_registers: set[str]) -> None:
        """
        Check the child iterators (registers and sections) of a register file, and that
        it does not expose ``get_memories``.

        Args:
            dut: register file under test
            sections: expected instance names of the sections
            readable_registers: expected instance names of the readable registers
            writeable_registers: expected instance names of the writable registers
        """
        self._test_register_iterators(dut=dut,
                                      readable_registers=readable_registers,
                                      writeable_registers=writeable_registers)
        self.__test_section_iterators(dut=dut,
                                      sections=sections)
        self.assertFalse(hasattr(dut, 'get_memories'))
