"""Testing utils for `Node`."""

import functools
from typing import Callable, TypeVar

from etils import epath

from mlcroissant._src.core.context import Context
from mlcroissant._src.core.context import Issues
from mlcroissant._src.structure_graph.base_node import Node
from mlcroissant._src.structure_graph.nodes.field import Field
from mlcroissant._src.structure_graph.nodes.file_object import FileObject
from mlcroissant._src.structure_graph.nodes.file_set import FileSet
from mlcroissant._src.structure_graph.nodes.record_set import RecordSet


class _EmptyNode(Node):
    """Minimal concrete `Node` whose hooks all do nothing, for use in tests."""

    def __post_init__(self):
        """Intentionally performs no post-initialization work."""

    @classmethod
    def from_jsonld(cls):
        """Stub deserializer: ignores its input and yields no node."""
        return None

    def to_json(self):
        """Stub serializer: produces no JSON."""
        return None


def _node_params(**kwargs):
    params = {
        "ctx": Context(folder=epath.Path()),
        "name": "node_name",
        "id": "node_name",
    }
    for key, value in kwargs.items():
        params[key] = value
    return params


T = TypeVar("T")  # Type of the node instance returned by `create_test_node`.


def create_test_node(cls: type[T], **kwargs) -> T:
    """Utils to easily create new nodes in tests.

    Fills in a default ``ctx`` (``Context(folder=epath.Path())``) plus ``name``
    and ``id`` (both ``"node_name"``). Any keyword in ``kwargs`` overrides the
    matching default, and extra keywords are forwarded to ``cls`` unchanged.

    Usage:

    Instead of writing:
    ```python
    node = FileSet(
        ctx=Context(folder=epath.Path()),
        name="file_set_name",
        id="file_set_name",
        description="Description",
    )
    ```

    Use:
    ```python
    node = create_test_file_set(description="Description")
    ```

    Args:
        cls: The node class to instantiate.
        **kwargs: Constructor keywords; they take precedence over the defaults.

    Returns:
        A new instance of ``cls``.
    """
    return cls(**_node_params(**kwargs))


def _test_node_factory(cls, default_name: str):
    """Return ``create_test_node`` pre-bound to ``cls``.

    ``name`` and ``id`` both default to ``default_name``. Callers can still
    override either one by keyword, because ``functools.partial`` lets
    call-time keywords take precedence over the bound ones.
    """
    # No return annotation on purpose: mypy then treats the result as Any, so
    # each assignment below can declare its precise Callable[..., NodeType].
    return functools.partial(
        create_test_node, cls, name=default_name, id=default_name
    )


create_test_field: Callable[..., Field] = _test_node_factory(Field, "field_name")
create_test_file_object: Callable[..., FileObject] = _test_node_factory(
    FileObject, "file_object_name"
)
create_test_file_set: Callable[..., FileSet] = _test_node_factory(
    FileSet, "file_set_name"
)
create_test_record_set: Callable[..., RecordSet] = _test_node_factory(
    RecordSet, "record_set_name"
)


# Ready-made nodes built with the default test parameters. Reusing the
# `create_test_*` factories keeps their default names/ids in one place.
# These objects are created once at import time and shared by every importer
# of this module, so tests should avoid mutating them in place.
empty_field: Field = create_test_field()
empty_file_object: FileObject = create_test_file_object()
empty_file_set: FileSet = create_test_file_set()
empty_node: Node = create_test_node(_EmptyNode)
empty_record_set: RecordSet = create_test_record_set()


def _assert_contains(messages, needle: str, kind: str):
    """Assert that at least one entry of ``messages`` contains ``needle``.

    ``messages`` is materialized once, so the check and the failure message
    see the same items even when a one-shot iterable is passed.

    Raises:
        AssertionError: If no entry contains ``needle``.
    """
    collected = list(messages)
    assert any(
        needle in message for message in collected
    ), f'Could not find {kind} "{needle}" in {collected}'


def assert_contain_error(issues: "Issues", error_msg: str):
    """Assert whether one of the errors in issues contains a message.

    Args:
        issues: Object whose ``errors`` are iterated; each entry is checked
            with ``error_msg in entry``.
        error_msg: Substring expected in at least one error.

    Raises:
        AssertionError: If no error contains ``error_msg``.
    """
    _assert_contains(issues.errors, error_msg, "error")


def assert_contain_warning(issues: "Issues", error_msg: str):
    """Assert whether one of the warnings in issues contains a message.

    Args:
        issues: Object whose ``warnings`` are iterated; each entry is checked
            with ``error_msg in entry``.
        error_msg: Substring expected in at least one warning.

    Raises:
        AssertionError: If no warning contains ``error_msg``.
    """
    _assert_contains(issues.warnings, error_msg, "warning")
