from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass

from flexai.capability import Capability
from flexai.message import Message, MessageContent


@dataclass
class TruncateMessages(Capability):
    """Truncate the input messages to the LLM to a maximum number.

    Only the trailing ``max_messages`` entries of the sequence are kept;
    anything before them is dropped.
    """

    # The maximum number of messages to keep. Zero or a negative value keeps
    # no messages at all.
    max_messages: int

    async def modify_messages(
        self, messages: Sequence[Message]
    ) -> AsyncGenerator[MessageContent | Sequence[Message], None]:
        """Yield the last ``max_messages`` items of ``messages``.

        Args:
            messages: The message sequence to truncate.

        Yields:
            A single slice of ``messages`` holding at most ``max_messages``
            trailing items. If the sequence is already short enough, the
            slice covers all of it.
        """
        # A plain ``messages[-n:]`` is wrong for n <= 0: ``-0`` is ``0``, so
        # the slice would return the whole sequence, and a negative n would
        # turn into a positive start offset. Compute the start index
        # explicitly instead.
        keep = max(self.max_messages, 0)
        start = max(len(messages) - keep, 0)
        yield messages[start:]
