#  This file is part of Pynguin.
#
#  SPDX-FileCopyrightText: 2019–2023 Pynguin Contributors
#
#  SPDX-License-Identifier: MIT
#
"""Provides an assertion visitor to transform assertions to AST."""
from __future__ import annotations

import ast

from typing import Any

import pynguin.assertion.assertion as ass
import pynguin.configuration as config
import pynguin.testcase.variablereference as vr
import pynguin.utils.ast_util as au
import pynguin.utils.namingscope as ns
import pynguin.utils.type_utils as tu


class PyTestAssertionToAstVisitor(ass.AssertionVisitor):
    """An assertion visitor that transforms assertions into AST nodes.

    Uses pytest for some special assertion constructs, such as float assertions
    (``pytest.approx``) and raised exceptions (``pytest.raises``).
    """

    def __init__(
        self,
        variable_names: ns.AbstractNamingScope,
        module_aliases: ns.AbstractNamingScope,
        common_modules: set[str],
        statement_node: ast.stmt,
    ):
        """Create a new assertion visitor.

        Args:
            variable_names: the naming scope that is used to resolve the names
                            of the variables used in the assertions.
            module_aliases: the naming scope that is used to resolve the aliases of the
                            modules used in the assertions.
            common_modules: the set of common modules that are used. Modules may be
                            added when transforming the assertions.
            statement_node: the ast representation of the statement for which we
                            generate assertions
        """
        self._common_modules = common_modules
        self._module_aliases = module_aliases
        self._variable_names = variable_names
        self._statement_node = statement_node
        self._assertion_nodes: list[ast.stmt] = []
        # Set once an ExceptionAssertion has wrapped the statement node.
        self._has_seen_exception = False

    @property
    def nodes(self) -> list[ast.stmt]:
        """Provides the ast nodes generated by this visitor.

        Some assertions need to wrap the original AST statement, e.g.,
        ExceptionAssertion, which wraps the given statement in a 'pytest.raises' block:

        with pytest.raises(AssertionError):
            module_0.foo()

        Other assertions do not modify the original AST statement but only
        add statements which are equivalent to the assertions:

        var_0 = module_0.foo()
        assert var_0 == 42

        Returns:
            the (possibly wrapped) statement node followed by the generated
            assertion statements.
        """
        if self._has_seen_exception:
            # An exception assertion must be the only assertion for a statement.
            assert (
                len(self._assertion_nodes) == 0
            ), "An exception assertion was seen but there are other assertions?"
        return [self._statement_node] + self._assertion_nodes

    @property
    def assertion_nodes(self) -> list[ast.stmt]:
        """Provides the raw assertion statements.

        These are only the statements that are added in addition to the statement
        node. An ExceptionAssertion does not contribute to this list (it modifies
        the original statement instead), so this property may only be accessed
        after at least one other assertion has been visited.

        Returns:
            The assertion nodes

        Raises:
            AssertionError: (via ``assert``, i.e., only when Python assertions are
                enabled) if no assertion statements have been generated.
        """
        assert len(self._assertion_nodes) > 0
        return self._assertion_nodes

    def visit_float_assertion(self, assertion: ass.FloatAssertion) -> None:
        """Creates an assertion of form
        "assert float_0 == pytest.approx(1, abs=<precision>, rel=<precision>)".

        The precision is taken from
        ``config.configuration.test_case_output.float_precision``.

        Args:
            assertion: the assertion that is visited.
        """
        left = au.create_full_name(
            self._variable_names, self._module_aliases, assertion.source, load=True
        )
        comp = self._construct_float_comparator(au.create_ast_constant(assertion.value))
        self._assertion_nodes.append(
            au.create_ast_assert(au.create_ast_compare(left, ast.Eq(), comp))
        )

    def visit_type_name_assertion(self, assertion: ass.TypeNameAssertion) -> None:
        """Creates an assertion of form:
        assert f"{type(int_0).__module__}.{type(int_0).__qualname__}" == "builtins.int".

        Args:
            assertion: the assertion that is visited.
        """
        # Each part gets its own freshly created name node, so that no AST node
        # is shared between two positions of the generated tree.
        module_part = self._create_type_attribute(assertion.source, "__module__")
        qualname_part = self._create_type_attribute(assertion.source, "__qualname__")
        joined = ast.JoinedStr(
            values=[
                ast.FormattedValue(value=module_part, conversion=-1),
                ast.Constant(value="."),
                ast.FormattedValue(value=qualname_part, conversion=-1),
            ]
        )
        self._assertion_nodes.append(
            ast.Assert(
                test=ast.Compare(
                    left=joined,
                    ops=[ast.Eq()],
                    comparators=[
                        ast.Constant(value=f"{assertion.module}.{assertion.qualname}")
                    ],
                )
            )
        )

    def _create_type_attribute(self, var: vr.Reference, attr: str) -> ast.Attribute:
        """Creates the expression ``type(<var>).<attr>``.

        Args:
            var: the reference whose type is inspected.
            attr: the attribute to read from the type, e.g., ``__module__``.

        Returns:
            A freshly constructed attribute node.
        """
        return ast.Attribute(
            value=ast.Call(
                func=ast.Name(id="type", ctx=ast.Load()),
                args=[
                    au.create_full_name(
                        self._variable_names, self._module_aliases, var, load=True
                    )
                ],
                keywords=[],
            ),
            attr=attr,
            ctx=ast.Load(),
        )

    def visit_object_assertion(self, assertion: ass.ObjectAssertion) -> None:
        """Creates an assertion of form "assert var0 == value".

        If the value is a bool or None, an identity check is generated instead,
        e.g., "assert var0 is False" or "assert var0 is None".

        Args:
            assertion: the assertion that is visited.
        """
        if isinstance(assertion.object, (bool, type(None))):
            self._assertion_nodes.append(
                self._create_constant_assert(
                    assertion.source, ast.Is(), assertion.object
                )
            )
        else:
            # Order matters: building the value may register module aliases
            # (enums) before the variable name is resolved.
            comp = self._create_assertable_object(assertion.object)
            left = au.create_full_name(
                self._variable_names, self._module_aliases, assertion.source, load=True
            )
            self._assertion_nodes.append(
                au.create_ast_assert(au.create_ast_compare(left, ast.Eq(), comp))
            )

    def _create_assertable_object(self, value: Any) -> ast.expr:
        """Recursively constructs an assertable object.

        Must handle all cases that `pynguin.utils.type_utils.is_assertable` accepts.

        Args:
            value: The value that should be generated.

        Returns:
            An assertion representation of the given value.

        Raises:
            AssertionError: If we encounter an object that we can't construct.
        """
        typ = type(value)
        if tu.is_enum(typ):
            enum_attr = self._construct_enum_attr(value)
            return au.create_ast_attribute(value.name, enum_attr)
        if tu.is_primitive_type(typ) or tu.is_none_type(typ):
            return au.create_ast_constant(value)
        if tu.is_set(typ) or tu.is_list(typ) or tu.is_tuple(typ):
            elements = [self._create_assertable_object(v) for v in value]
            if tu.is_set(typ):
                return au.create_ast_set(elements)
            if tu.is_list(typ):
                return au.create_ast_list(elements)
            return au.create_ast_tuple(elements)
        if tu.is_dict(typ):
            # Keys and values are iterated in the same (insertion) order.
            keys = [self._create_assertable_object(key) for key in value]
            values = [self._create_assertable_object(v) for v in value.values()]
            return au.create_ast_dict(keys, values)
        raise AssertionError(f"Cannot create assertion object of type {type(value)}")

    def _create_constant_assert(
        self, var: vr.Reference, operator: ast.cmpop, value: Any
    ) -> ast.Assert:
        """Creates an assertion of form "assert <var> <operator> <constant>".

        Args:
            var: the reference that is compared.
            operator: the comparison operator to use.
            value: the constant value to compare against.

        Returns:
            The assert statement.
        """
        left = au.create_full_name(
            self._variable_names, self._module_aliases, var, load=True
        )
        comp = au.create_ast_constant(value)
        return au.create_ast_assert(au.create_ast_compare(left, operator, comp))

    def _construct_float_comparator(self, comp):
        """Wraps the given expression into a ``pytest.approx(...)`` call.

        Registers ``pytest`` as a common module.

        Args:
            comp: the expression holding the expected float value.

        Returns:
            The call node ``pytest.approx(comp, abs=<precision>, rel=<precision>)``.
        """
        self._common_modules.add("pytest")
        float_precision = config.configuration.test_case_output.float_precision
        func = au.create_ast_attribute("approx", au.create_ast_name("pytest"))
        keywords = [
            au.create_ast_keyword("abs", au.create_ast_constant(float_precision)),
            au.create_ast_keyword("rel", au.create_ast_constant(float_precision)),
        ]
        return au.create_ast_call(func, [comp], keywords)

    def _construct_enum_attr(self, value) -> ast.Attribute:
        """Creates the expression ``<module_alias>.<EnumClassName>`` for an enum member.

        Args:
            value: the enum member whose class should be referenced.

        Returns:
            The attribute node referencing the enum class.
        """
        module = self._module_aliases.get_name(value.__class__.__module__)
        enum_name = value.__class__.__name__
        return au.create_ast_attribute(enum_name, au.create_ast_name(module))

    def visit_collection_length_assertion(
        self, assertion: ass.CollectionLengthAssertion
    ) -> None:
        """Creates an assertion of form "assert len(var_0) == 3".

        Args:
            assertion: the assertion that is visited.
        """
        left = au.create_ast_call(
            au.create_ast_name("len"),
            [
                au.create_full_name(
                    self._variable_names,
                    self._module_aliases,
                    assertion.source,
                    load=True,
                )
            ],
            [],
        )
        self._assertion_nodes.append(
            au.create_ast_assert(
                au.create_ast_compare(
                    left, ast.Eq(), au.create_ast_constant(assertion.length)
                )
            )
        )

    def visit_exception_assertion(self, assertion: ass.ExceptionAssertion) -> None:
        """Wraps the statement node into a ``with pytest.raises(<Exception>):`` block.

        Unlike the other assertions, this does not add an assertion statement but
        replaces the statement node.

        Args:
            assertion: the assertion that is visited.
        """
        assert (
            not self._has_seen_exception
        ), "Cannot assert multiple exceptions on same statement"
        self._has_seen_exception = True
        self._common_modules.add("pytest")
        assert len(self._assertion_nodes) == 0

        if assertion.module == "builtins":
            # No need to add an import for builtins
            exception_ast_name: ast.Name | ast.Attribute = au.create_ast_name(
                assertion.exception_type_name, store=False
            )
        else:
            exception_ast_name = au.create_ast_attribute(
                attr=assertion.exception_type_name,
                value=au.create_ast_name(
                    self._module_aliases.get_name(assertion.module), store=False
                ),
                store=False,
            )

        # Exception assertions are special, we need to wrap the statement in
        # with pytest.raises(...)
        self._statement_node = ast.With(
            items=[
                ast.withitem(
                    context_expr=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="pytest", ctx=ast.Load()),
                            attr="raises",
                            ctx=ast.Load(),
                        ),
                        args=[exception_ast_name],
                        keywords=[],
                    )
                )
            ],
            body=[self._statement_node],
        )
