#  This file is part of Pynguin.
#
#  SPDX-FileCopyrightText: 2019–2022 Pynguin Contributors
#
#  SPDX-License-Identifier: LGPL-3.0-or-later
#
"""Provides an assertion visitor to transform assertions to AST."""
from __future__ import annotations

import ast
from typing import Any

import pynguin.assertion.assertion as ass
import pynguin.configuration as config
import pynguin.testcase.variablereference as vr
import pynguin.utils.ast_util as au
import pynguin.utils.namingscope as ns
import pynguin.utils.type_utils as tu


class AssertionToAstVisitor(ass.AssertionVisitor):
    """An assertion visitor that transforms assertions into AST nodes.

    Every ``visit_*`` method appends exactly one ``assert`` statement to the
    list that is exposed via :attr:`nodes`.
    """

    def __init__(
        self,
        variable_names: ns.AbstractNamingScope,
        module_aliases: ns.AbstractNamingScope,
        common_modules: set[str],
    ):
        """Create a new assertion visitor.

        Args:
            variable_names: the naming scope that is used to resolve the names
                            of the variables used in the assertions.
            module_aliases: the naming scope that is used to resolve the aliases of the
                            modules used in the assertions.
            common_modules: the set of common modules that are used. Modules may be
                            added when transforming the assertions.
        """
        self._common_modules = common_modules
        self._module_aliases = module_aliases
        self._variable_names = variable_names
        self._nodes: list[ast.stmt] = []

    @property
    def nodes(self) -> list[ast.stmt]:
        """Provides the ast nodes generated by this visitor.

        Returns:
            the ast nodes generated by this visitor.
        """
        return self._nodes

    def visit_float_assertion(self, assertion: ass.FloatAssertion) -> None:
        """Creates an assertion of form
        "assert float_0 == pytest.approx(1, abs=0.01, rel=0.01)"

        Args:
            assertion: the assertion that is visited.

        """
        left = au.create_full_name(
            self._variable_names, self._module_aliases, assertion.source, load=True
        )
        comp = self._construct_float_comparator(au.create_ast_constant(assertion.value))
        self._nodes.append(
            au.create_ast_assert(au.create_ast_compare(left, ast.Eq(), comp))
        )

    def visit_not_none_assertion(self, assertion: ass.NotNoneAssertion) -> None:
        """Creates an assertion of form "assert var0 is not None".

        Args:
            assertion: the assertion that is visited.
        """
        self._nodes.append(
            self._create_constant_assert(assertion.source, ast.IsNot(), None)
        )

    def visit_object_assertion(self, assertion: ass.ObjectAssertion) -> None:
        """Creates an assertion of form "assert var0 == value", or
        "assert var0 is value" if the value is a bool or None.

        Args:
            assertion: the assertion that is visited.
        """
        if isinstance(assertion.object, (bool, type(None))):
            # Singletons are compared by identity instead of equality.
            self._nodes.append(
                self._create_constant_assert(
                    assertion.source, ast.Is(), assertion.object
                )
            )
        else:
            comp = self._create_assertable_object(assertion.object)
            left = au.create_full_name(
                self._variable_names, self._module_aliases, assertion.source, load=True
            )
            self._nodes.append(
                au.create_ast_assert(au.create_ast_compare(left, ast.Eq(), comp))
            )

    def _create_assertable_object(self, value: Any):
        """Recursively constructs an assertable object. Must handle all cases that
        `pynguin.utils.type_utils.is_assertable` accepts.

        Enum members become an attribute access on their enum class, primitives
        and None become constants, and sets, lists, tuples and dicts are built
        element by element via recursion.

        Args:
            value: The value that should be generated.

        Returns:
            An assertion representation of the given value.

        Raises:
            AssertionError: If we encounter an object that we can't construct.
        """
        tp_ = type(value)
        if tu.is_enum(tp_):
            enum_attr = self._construct_enum_attr(value)
            return au.create_ast_attribute(value.name, enum_attr)
        if tu.is_primitive_type(tp_) or tu.is_none_type(tp_):
            return au.create_ast_constant(value)
        if tu.is_set(value) or tu.is_list(value) or tu.is_tuple(value):
            elements = [self._create_assertable_object(v) for v in value]
            if tu.is_set(value):
                return au.create_ast_set(elements)
            if tu.is_list(value):
                return au.create_ast_list(elements)
            return au.create_ast_tuple(elements)
        if tu.is_dict(value):
            # Iterating a dict yields its keys, in the same order as .values().
            keys = [self._create_assertable_object(k) for k in value]
            values = [self._create_assertable_object(v) for v in value.values()]
            return au.create_ast_dict(keys, values)
        raise AssertionError(f"Cannot create assertion object of type {type(value)}")

    def _create_constant_assert(
        self, var: vr.Reference, operator: ast.cmpop, value: Any
    ) -> ast.Assert:
        """Creates an assertion of form "assert var <operator> value", where the
        value is emitted as a constant.

        Args:
            var: the reference that forms the left-hand side of the comparison.
            operator: the comparison operator to use.
            value: the value that is turned into the right-hand side constant.

        Returns:
            The created assert statement.
        """
        left = au.create_full_name(
            self._variable_names, self._module_aliases, var, load=True
        )
        comp = au.create_ast_constant(value)
        return au.create_ast_assert(au.create_ast_compare(left, operator, comp))

    def _construct_float_comparator(self, comp):
        """Wraps the given node in a call of form
        "pytest.approx(comp, abs=precision, rel=precision)".

        The precision is read from the test case output configuration. As a side
        effect, "pytest" is added to the set of common modules.

        Args:
            comp: the node that becomes the first argument of the call.

        Returns:
            The call node that can be used as the right-hand side of a comparison.
        """
        self._common_modules.add("pytest")
        float_precision = config.configuration.test_case_output.float_precision
        func = au.create_ast_attribute("approx", au.create_ast_name("pytest"))
        keywords = [
            au.create_ast_keyword("abs", au.create_ast_constant(float_precision)),
            au.create_ast_keyword("rel", au.create_ast_constant(float_precision)),
        ]
        return au.create_ast_call(func, [comp], keywords)

    def _construct_enum_attr(self, value) -> ast.Attribute:
        """Creates an attribute access of form "module_alias.EnumClass" for the
        class of the given enum member.

        Args:
            value: the enum member whose class should be referenced.

        Returns:
            The attribute node that references the enum class.
        """
        module = self._module_aliases.get_name(value.__class__.__module__)
        enum_name = value.__class__.__name__
        return au.create_ast_attribute(enum_name, au.create_ast_name(module))

    def visit_collection_length_assertion(
        self, assertion: ass.CollectionLengthAssertion
    ) -> None:
        """Creates an assertion of form "assert len(var0) == length".

        Args:
            assertion: the assertion that is visited.
        """
        source = au.create_full_name(
            self._variable_names, self._module_aliases, assertion.source, load=True
        )
        # The keywords of a call node are a list; use an empty list (not a dict)
        # to stay consistent with the float comparator's call construction.
        left = au.create_ast_call(au.create_ast_name("len"), [source], [])
        self._nodes.append(
            au.create_ast_assert(
                au.create_ast_compare(
                    left, ast.Eq(), au.create_ast_constant(assertion.length)
                )
            )
        )
