import pytest

from .classification_model import ClassificationModel
from .memoryset import LabeledMemoryLookup
from .telemetry import FeedbackCategory, LabelPrediction


def test_get_prediction(model: ClassificationModel):
    """Fetching one prediction by its id returns the stored label and input text."""
    soup_query = "Do you love soup?"
    first_id = model.predict([soup_query, "Are cats cute?"])[0].prediction_id
    fetched = LabelPrediction.get(first_id)
    assert fetched is not None
    assert fetched.label == 0
    assert fetched.input_value == soup_query


def test_get_predictions(model: ClassificationModel):
    """Fetching several predictions by id returns them in the order the ids were given."""
    soup_query = "Do you love soup?"
    predictions = model.predict([soup_query, "Are cats cute?"])
    requested_ids = [predictions[i].prediction_id for i in range(2)]
    fetched = LabelPrediction.get(requested_ids)
    assert len(fetched) == 2
    first, second = fetched[0], fetched[1]
    assert first.label == 0
    assert first.input_value == soup_query
    assert second.label == 1


def test_get_predictions_with_expected_label_match(model: ClassificationModel):
    """Filtering by expected-label agreement only counts predictions that carry an expected label.

    Three predictions are tagged. One matches its expected label, one does not, and
    one has no expected label at all, so each filtered query should find exactly one.
    """
    tag = "expected_label_match"
    model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0, 0], tags={tag})
    model.predict("no expectations", tags={tag})
    assert len(model.predictions(tag=tag)) == 3
    for matched in (True, False):
        assert len(model.predictions(expected_label_match=matched, tag=tag)) == 1


def test_get_prediction_memory_lookups(model: ClassificationModel):
    """A prediction exposes a non-empty list of labeled memory lookups."""
    prediction = model.predict("Do you love soup?")
    assert isinstance(prediction.memory_lookups, list)
    assert prediction.memory_lookups, "expected at least one memory lookup"
    for lookup in prediction.memory_lookups:
        assert isinstance(lookup, LabeledMemoryLookup)


def test_record_feedback(model: ClassificationModel):
    """Recorded feedback shows up in the prediction's ``feedback`` mapping.

    The value recorded for the "correct" category is whether the model predicted
    label 0 for the soup query. The test therefore depends on that prediction
    being 0, so the precondition is asserted up front. A wrong prediction then
    fails with a clear message instead of surfacing as a feedback mismatch.
    """
    prediction = model.predict("Do you love soup?")
    # Precondition: the rest of the test assumes the soup query is predicted as label 0.
    assert prediction.label == 0, f"expected label 0 for the soup query, got {prediction.label!r}"
    assert "correct" not in prediction.feedback
    prediction.record_feedback(category="correct", value=prediction.label == 0)
    assert prediction.feedback["correct"] is True


def test_record_feedback_with_invalid_value(model: ClassificationModel):
    """Recording a string as the "correct" feedback value raises a ValueError."""
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        prediction = model.predict("Do you love soup?")
        prediction.record_feedback(category="correct", value="not a bool")  # type: ignore


def test_record_feedback_with_inconsistent_value_for_category(model: ClassificationModel):
    """A category first recorded with a bool rejects a later float value."""
    query = "Do you love soup?"
    first = model.predict(query)
    first.record_feedback(category="correct", value=True)
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        second = model.predict(query)
        second.record_feedback(category="correct", value=-1.0)


def test_delete_feedback(model: ClassificationModel):
    """Feedback that has been deleted no longer appears on the prediction."""
    category = "test_delete"
    prediction = model.predict("Do you love soup?")
    prediction.record_feedback(category=category, value=True)
    assert category in prediction.feedback
    prediction.delete_feedback(category)
    assert category not in prediction.feedback


def test_list_feedback_categories(model: ClassificationModel):
    """Recording feedback registers each category with the type of the value recorded.

    ``value_type`` holds a class, so it is checked with identity (``is``) rather
    than ``==``. This is the idiomatic way to assert that a type is exactly
    ``bool`` or ``float``.
    """
    prediction = model.predict("Do you love soup?")
    prediction.record_feedback(category="correct", value=True)
    prediction.record_feedback(category="confidence", value=0.8)
    categories = FeedbackCategory.all()
    assert len(categories) >= 2
    assert any(c.name == "correct" and c.value_type is bool for c in categories)
    assert any(c.name == "confidence" and c.value_type is float for c in categories)


def test_drop_feedback_category(model: ClassificationModel):
    """Dropping a feedback category removes it from the category list and from a refreshed prediction."""
    category = "test_category"
    prediction = model.predict("Do you love soup?")
    prediction.record_feedback(category=category, value=True)
    assert category in {c.name for c in FeedbackCategory.all()}
    FeedbackCategory.drop(category)
    assert category not in {c.name for c in FeedbackCategory.all()}
    # Refresh the prediction before confirming the dropped category is gone from its feedback.
    prediction.refresh()
    assert category not in prediction.feedback


def test_update_prediction(model: ClassificationModel):
    """``update`` can set, replace, and clear a prediction's expected label and tags."""
    prediction = model.predict("Do you love soup?")
    # A fresh prediction starts with no expected label and an empty tag set.
    assert prediction.expected_label is None
    assert prediction.tags == set()

    # Setting only the expected label.
    prediction.update(expected_label=1)
    assert prediction.expected_label == 1

    # Setting only the tags.
    prediction.update(tags={"test_tag1", "test_tag2"})
    assert prediction.tags == {"test_tag1", "test_tag2"}

    # Setting both fields in a single call.
    prediction.update(expected_label=0, tags={"new_tag"})
    assert (prediction.expected_label, prediction.tags) == (0, {"new_tag"})

    # Passing None clears the expected label ...
    prediction.update(expected_label=None)
    assert prediction.expected_label is None

    # ... and resets the tags to an empty set.
    prediction.update(tags=None)
    assert prediction.tags == set()
