Coverage for src/lite_agent/agent.py: 53%
309 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-25 22:58 +0900
1import time
2from collections.abc import AsyncGenerator, Callable, Sequence
3from pathlib import Path
4from typing import Any, Optional
6from funcall import Funcall
7from jinja2 import Environment, FileSystemLoader
9from lite_agent.client import BaseLLMClient, LiteLLMClient, ReasoningConfig
10from lite_agent.constants import CompletionMode, ToolName
11from lite_agent.loggers import logger
12from lite_agent.response_handlers import CompletionResponseHandler, ResponsesAPIHandler
13from lite_agent.types import (
14 AgentChunk,
15 AssistantTextContent,
16 AssistantToolCall,
17 AssistantToolCallResult,
18 FunctionCallEvent,
19 FunctionCallOutputEvent,
20 RunnerMessages,
21 ToolCall,
22 message_to_llm_dict,
23 system_message_to_llm_dict,
24)
25from lite_agent.types.messages import NewAssistantMessage, NewSystemMessage, NewUserMessage
# Prompt templates are shipped in the "templates" directory next to this module.
TEMPLATES_DIR = Path(__file__).parent.joinpath("templates")

# Shared Jinja2 environment used to load every prompt template below.
jinja_env = Environment(
    loader=FileSystemLoader(str(TEMPLATES_DIR)),
    autoescape=True,
)

# Templates are looked up once at import time and rendered by Agent when it
# assembles the system prompt.
HANDOFFS_SOURCE_INSTRUCTIONS_TEMPLATE = jinja_env.get_template(
    "handoffs_source_instructions.xml.j2",
)
HANDOFFS_TARGET_INSTRUCTIONS_TEMPLATE = jinja_env.get_template(
    "handoffs_target_instructions.xml.j2",
)
WAIT_FOR_USER_INSTRUCTIONS_TEMPLATE = jinja_env.get_template(
    "wait_for_user_instructions.xml.j2",
)
35class Agent:
36 def __init__(
37 self,
38 *,
39 model: str | BaseLLMClient,
40 name: str,
41 instructions: str,
42 tools: list[Callable] | None = None,
43 handoffs: list["Agent"] | None = None,
44 message_transfer: Callable[[RunnerMessages], RunnerMessages] | None = None,
45 completion_condition: str = "stop",
46 reasoning: ReasoningConfig = None,
47 stop_before_tools: list[str] | list[Callable] | None = None,
48 ) -> None:
49 self.name = name
50 self.instructions = instructions
51 self.reasoning = reasoning
52 # Convert stop_before_functions to function names
53 if stop_before_tools:
54 self.stop_before_functions = set()
55 for func in stop_before_tools:
56 if isinstance(func, str):
57 self.stop_before_functions.add(func)
58 elif callable(func):
59 self.stop_before_functions.add(func.__name__)
60 else:
61 msg = f"stop_before_functions must contain strings or callables, got {type(func)}"
62 raise TypeError(msg)
63 else:
64 self.stop_before_functions = set()
66 if isinstance(model, BaseLLMClient):
67 # If model is a BaseLLMClient instance, use it directly
68 self.client = model
69 else:
70 # Otherwise, create a LitellmClient instance
71 self.client = LiteLLMClient(
72 model=model,
73 reasoning=reasoning,
74 )
75 self.completion_condition = completion_condition
76 self.handoffs = handoffs if handoffs else []
77 self._parent: Agent | None = None
78 self.message_transfer = message_transfer
79 # Initialize Funcall with regular tools
80 self.fc = Funcall(tools)
82 # Add wait_for_user tool if completion condition is "call"
83 if completion_condition == CompletionMode.CALL:
84 self._add_wait_for_user_tool()
86 # Set parent for handoff agents
87 if handoffs:
88 for handoff_agent in handoffs:
89 handoff_agent.parent = self
90 self._add_transfer_tools(handoffs)
92 # Add transfer_to_parent tool if this agent has a parent (for cases where parent is set externally)
93 if self.parent is not None:
94 self.add_transfer_to_parent_tool()
96 @property
97 def parent(self) -> Optional["Agent"]:
98 return self._parent
100 @parent.setter
101 def parent(self, value: Optional["Agent"]) -> None:
102 self._parent = value
103 if value is not None:
104 self.add_transfer_to_parent_tool()
106 def _add_transfer_tools(self, handoffs: list["Agent"]) -> None:
107 """Add transfer function for handoff agents using dynamic tools.
109 Creates a single 'transfer_to_agent' function that accepts a 'name' parameter
110 to specify which agent to transfer the conversation to.
112 Args:
113 handoffs: List of Agent objects that can be transferred to
114 """
115 # Collect all agent names for validation
116 agent_names = [agent.name for agent in handoffs]
118 def transfer_handler(name: str) -> str:
119 """Handler for transfer_to_agent function."""
120 if name in agent_names:
121 return f"Transferring to agent: {name}"
123 available_agents = ", ".join(agent_names)
124 return f"Agent '{name}' not found. Available agents: {available_agents}"
126 # Add single dynamic tool for all transfers
127 self.fc.add_dynamic_tool(
128 name=ToolName.TRANSFER_TO_AGENT,
129 description="Transfer conversation to another agent.",
130 parameters={
131 "name": {
132 "type": "string",
133 "description": "The name of the agent to transfer to",
134 "enum": agent_names,
135 },
136 },
137 required=["name"],
138 handler=transfer_handler,
139 )
141 def add_transfer_to_parent_tool(self) -> None:
142 """Add transfer_to_parent function for agents that have a parent.
144 This tool allows the agent to transfer back to its parent when:
145 - The current task is completed
146 - The agent cannot solve the current problem
147 - Escalation to a higher level is needed
148 """
150 def transfer_to_parent_handler() -> str:
151 """Handler for transfer_to_parent function."""
152 if self.parent:
153 return f"Transferring back to parent agent: {self.parent.name}"
154 return "No parent agent found"
156 # Add dynamic tool for parent transfer
157 self.fc.add_dynamic_tool(
158 name=ToolName.TRANSFER_TO_PARENT,
159 description="Transfer conversation back to parent agent when current task is completed or cannot be solved by current agent",
160 parameters={},
161 required=[],
162 handler=transfer_to_parent_handler,
163 )
165 def add_handoff(self, agent: "Agent") -> None:
166 """Add a handoff agent after initialization.
168 This method allows adding handoff agents dynamically after the agent
169 has been constructed. It properly sets up parent-child relationships
170 and updates the transfer tools.
172 Args:
173 agent: The agent to add as a handoff target
174 """
175 # Add to handoffs list if not already present
176 if agent not in self.handoffs:
177 self.handoffs.append(agent)
179 # Set parent relationship
180 agent.parent = self
182 # Add transfer_to_parent tool to the handoff agent
183 agent.add_transfer_to_parent_tool()
185 # Remove existing transfer tool if it exists and recreate with all agents
186 try:
187 # Try to remove the existing transfer tool
188 if hasattr(self.fc, "remove_dynamic_tool"):
189 self.fc.remove_dynamic_tool(ToolName.TRANSFER_TO_AGENT)
190 except Exception as e:
191 # If removal fails, log and continue anyway
192 logger.debug(f"Failed to remove existing transfer tool: {e}")
194 # Regenerate transfer tools to include the new agent
195 self._add_transfer_tools(self.handoffs)
197 def prepare_completion_messages(self, messages: RunnerMessages) -> list[dict]:
198 """Prepare messages for completions API (with conversion)."""
199 converted_messages = self._convert_responses_to_completions_format(messages)
200 instructions = self.instructions
201 if self.handoffs:
202 instructions = HANDOFFS_SOURCE_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
203 if self.parent:
204 instructions = HANDOFFS_TARGET_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
205 if self.completion_condition == "call":
206 instructions = WAIT_FOR_USER_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
207 return [
208 system_message_to_llm_dict(
209 NewSystemMessage(
210 content=f"You are {self.name}. {instructions}",
211 ),
212 ),
213 *converted_messages,
214 ]
216 def prepare_responses_messages(self, messages: RunnerMessages) -> list[dict[str, Any]]:
217 """Prepare messages for responses API (no conversion, just add system message if needed)."""
218 instructions = self.instructions
219 if self.handoffs:
220 instructions = HANDOFFS_SOURCE_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
221 if self.parent:
222 instructions = HANDOFFS_TARGET_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
223 if self.completion_condition == "call":
224 instructions = WAIT_FOR_USER_INSTRUCTIONS_TEMPLATE.render(extra_instructions=None) + "\n\n" + instructions
225 res: list[dict[str, Any]] = [
226 {
227 "role": "system",
228 "content": f"You are {self.name}. {instructions}",
229 },
230 ]
231 for message in messages:
232 if isinstance(message, NewAssistantMessage):
233 for item in message.content:
234 if isinstance(item, AssistantTextContent):
235 res.append(
236 {
237 "role": "assistant",
238 "content": item.text,
239 },
240 )
241 elif isinstance(item, AssistantToolCall):
242 res.append(
243 {
244 "type": "function_call",
245 "call_id": item.call_id,
246 "name": item.name,
247 "arguments": item.arguments,
248 },
249 )
250 elif isinstance(item, AssistantToolCallResult):
251 res.append(
252 {
253 "type": "function_call_output",
254 "call_id": item.call_id,
255 "output": item.output,
256 },
257 )
258 elif isinstance(message, NewSystemMessage):
259 res.append(
260 {
261 "role": "system",
262 "content": message.content,
263 },
264 )
265 elif isinstance(message, NewUserMessage):
266 contents = []
267 for item in message.content:
268 match item.type:
269 case "text":
270 contents.append(
271 {
272 "type": "input_text",
273 "text": item.text,
274 },
275 )
276 case "image":
277 contents.append(
278 {
279 "type": "input_image",
280 "image_url": item.image_url,
281 },
282 )
283 case "file":
284 contents.append(
285 {
286 "type": "input_file",
287 "file_id": item.file_id,
288 "file_name": item.file_name,
289 },
290 )
291 res.append(
292 {
293 "role": message.role,
294 "content": contents,
295 },
296 )
297 return res
299 async def completion(
300 self,
301 messages: RunnerMessages,
302 record_to_file: Path | None = None,
303 reasoning: ReasoningConfig = None,
304 *,
305 streaming: bool = True,
306 ) -> AsyncGenerator[AgentChunk, None]:
307 # Apply message transfer callback if provided - always use legacy format for LLM compatibility
308 processed_messages = messages
309 if self.message_transfer:
310 logger.debug(f"Applying message transfer callback for agent {self.name}")
311 processed_messages = self.message_transfer(messages)
313 # For completions API, use prepare_completion_messages
314 self.message_histories = self.prepare_completion_messages(processed_messages)
316 tools = self.fc.get_tools(target="completion")
317 resp = await self.client.completion(
318 messages=self.message_histories,
319 tools=tools,
320 tool_choice="auto", # TODO: make this configurable
321 reasoning=reasoning,
322 streaming=streaming,
323 )
325 # Use response handler for unified processing
326 handler = CompletionResponseHandler()
327 return handler.handle(resp, streaming=streaming, record_to=record_to_file)
329 async def responses(
330 self,
331 messages: RunnerMessages,
332 record_to_file: Path | None = None,
333 reasoning: ReasoningConfig = None,
334 *,
335 streaming: bool = True,
336 ) -> AsyncGenerator[AgentChunk, None]:
337 # Apply message transfer callback if provided - always use legacy format for LLM compatibility
338 processed_messages = messages
339 if self.message_transfer:
340 logger.debug(f"Applying message transfer callback for agent {self.name}")
341 processed_messages = self.message_transfer(messages)
343 # For responses API, use prepare_responses_messages (no conversion)
344 self.message_histories = self.prepare_responses_messages(processed_messages)
345 tools = self.fc.get_tools()
346 resp = await self.client.responses(
347 messages=self.message_histories,
348 tools=tools,
349 tool_choice="auto", # TODO: make this configurable
350 reasoning=reasoning,
351 streaming=streaming,
352 )
353 # Use response handler for unified processing
354 handler = ResponsesAPIHandler()
355 return handler.handle(resp, streaming=streaming, record_to=record_to_file)
357 async def list_require_confirm_tools(self, tool_calls: Sequence[ToolCall] | None) -> Sequence[ToolCall]:
358 if not tool_calls:
359 return []
360 results = []
361 for tool_call in tool_calls:
362 function_name = tool_call.function.name
364 # Check if function is in dynamic stop_before_functions list
365 if function_name in self.stop_before_functions:
366 logger.debug('Tool call "%s" requires confirmation (stop_before_functions)', tool_call.id)
367 results.append(tool_call)
368 continue
370 # Check decorator-based require_confirmation
371 tool_func = self.fc.function_registry.get(function_name)
372 if not tool_func:
373 logger.warning("Tool function %s not found in registry", function_name)
374 continue
375 tool_meta = self.fc.get_tool_meta(function_name)
376 if tool_meta["require_confirm"]:
377 logger.debug('Tool call "%s" requires confirmation (decorator)', tool_call.id)
378 results.append(tool_call)
379 return results
381 async def handle_tool_calls(self, tool_calls: Sequence[ToolCall] | None, context: Any | None = None) -> AsyncGenerator[FunctionCallEvent | FunctionCallOutputEvent, None]: # noqa: ANN401
382 if not tool_calls:
383 return
384 if tool_calls:
385 for tool_call in tool_calls:
386 tool_func = self.fc.function_registry.get(tool_call.function.name)
387 if not tool_func:
388 logger.warning("Tool function %s not found in registry", tool_call.function.name)
389 continue
391 for tool_call in tool_calls:
392 yield FunctionCallEvent(
393 call_id=tool_call.id,
394 name=tool_call.function.name,
395 arguments=tool_call.function.arguments or "",
396 )
397 start_time = time.time()
398 try:
399 content = await self.fc.call_function_async(tool_call.function.name, tool_call.function.arguments or "", context)
400 end_time = time.time()
401 execution_time_ms = int((end_time - start_time) * 1000)
402 yield FunctionCallOutputEvent(
403 tool_call_id=tool_call.id,
404 name=tool_call.function.name,
405 content=str(content),
406 execution_time_ms=execution_time_ms,
407 )
408 except Exception as e:
409 logger.exception("Tool call %s failed", tool_call.id)
410 end_time = time.time()
411 execution_time_ms = int((end_time - start_time) * 1000)
412 yield FunctionCallOutputEvent(
413 tool_call_id=tool_call.id,
414 name=tool_call.function.name,
415 content=str(e),
416 execution_time_ms=execution_time_ms,
417 )
419 def _convert_responses_to_completions_format(self, messages: RunnerMessages) -> list[dict]:
420 """Convert messages from responses API format to completions API format."""
421 converted_messages = []
422 i = 0
424 while i < len(messages):
425 message = messages[i]
426 message_dict = message_to_llm_dict(message) if isinstance(message, (NewUserMessage, NewSystemMessage, NewAssistantMessage)) else message
428 message_type = message_dict.get("type")
429 role = message_dict.get("role")
431 if role == "assistant":
432 # Extract tool_calls from content if present
433 tool_calls = []
434 content = message_dict.get("content", [])
436 # Handle both string and array content
437 if isinstance(content, list):
438 # Extract tool_calls from content array and filter out non-text content
439 filtered_content = []
440 for item in content:
441 if isinstance(item, dict):
442 if item.get("type") == "tool_call":
443 tool_call = {
444 "id": item.get("call_id", ""),
445 "type": "function",
446 "function": {
447 "name": item.get("name", ""),
448 "arguments": item.get("arguments", "{}"),
449 },
450 "index": len(tool_calls),
451 }
452 tool_calls.append(tool_call)
453 elif item.get("type") == "text":
454 filtered_content.append(item)
455 # Skip tool_call_result - they should be handled by separate function_call_output messages
457 # Update content to only include text items
458 if filtered_content:
459 message_dict = message_dict.copy()
460 message_dict["content"] = filtered_content
461 elif tool_calls:
462 # If we have tool_calls but no text content, set content to None per OpenAI API spec
463 message_dict = message_dict.copy()
464 message_dict["content"] = None
466 # Look ahead for function_call messages (legacy support)
467 j = i + 1
468 while j < len(messages):
469 next_message = messages[j]
470 next_dict = message_to_llm_dict(next_message) if isinstance(next_message, (NewUserMessage, NewSystemMessage, NewAssistantMessage)) else next_message
472 if next_dict.get("type") == "function_call":
473 tool_call = {
474 "id": next_dict["call_id"], # type: ignore
475 "type": "function",
476 "function": {
477 "name": next_dict["name"], # type: ignore
478 "arguments": next_dict["arguments"], # type: ignore
479 },
480 "index": len(tool_calls),
481 }
482 tool_calls.append(tool_call)
483 j += 1
484 else:
485 break
487 # Create assistant message with tool_calls if any
488 assistant_msg = message_dict.copy()
489 if tool_calls:
490 assistant_msg["tool_calls"] = tool_calls # type: ignore
492 # Convert content format for OpenAI API compatibility
493 content = assistant_msg.get("content", [])
494 if isinstance(content, list):
495 # Extract text content and convert to string using list comprehension
496 text_parts = [item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "text"]
497 assistant_msg["content"] = " ".join(text_parts) if text_parts else None
499 converted_messages.append(assistant_msg)
500 i = j # Skip the function_call messages we've processed
502 elif message_type == "function_call_output":
503 # Convert to tool message
504 converted_messages.append(
505 {
506 "role": "tool",
507 "tool_call_id": message_dict["call_id"], # type: ignore
508 "content": message_dict["output"], # type: ignore
509 },
510 )
511 i += 1
513 elif message_type == "function_call":
514 # This should have been processed with the assistant message
515 # Skip it if we encounter it standalone
516 i += 1
518 else:
519 # Regular message (user, system)
520 converted_msg = message_dict.copy()
522 # Handle new Response API format for user messages
523 content = message_dict.get("content")
524 if role == "user" and isinstance(content, list):
525 converted_msg["content"] = self._convert_user_content_to_completions_format(content) # type: ignore
527 converted_messages.append(converted_msg)
528 i += 1
530 return converted_messages
532 def _convert_user_content_to_completions_format(self, content: list) -> list:
533 """Convert user message content from Response API format to Completion API format."""
534 # Handle the case where content might not actually be a list due to test mocking
535 if type(content) is not list: # Use type() instead of isinstance() to avoid test mocking issues
536 return content
538 converted_content = []
539 for item in content:
540 # Convert Pydantic objects to dict first
541 if hasattr(item, "model_dump"):
542 item_dict = item.model_dump()
543 elif hasattr(item, "dict"): # For older Pydantic versions
544 item_dict = item.dict()
545 elif isinstance(item, dict):
546 item_dict = item
547 else:
548 # Handle non-dict items (shouldn't happen, but just in case)
549 converted_content.append(item)
550 continue
552 item_type = item_dict.get("type")
553 if item_type in ["input_text", "text"]:
554 # Convert ResponseInputText or new text format to completion API format
555 converted_content.append(
556 {
557 "type": "text",
558 "text": item_dict["text"],
559 },
560 )
561 elif item_type in ["input_image", "image"]:
562 # Convert ResponseInputImage to completion API format
563 if item_dict.get("file_id"):
564 msg = "File ID input is not supported for Completion API"
565 raise ValueError(msg)
567 if not item_dict.get("image_url"):
568 msg = "ResponseInputImage must have either file_id or image_url"
569 raise ValueError(msg)
571 # Build image_url object with detail inside
572 image_data = {"url": item_dict["image_url"]}
573 detail = item_dict.get("detail", "auto")
574 if detail: # Include detail if provided
575 image_data["detail"] = detail
577 converted_content.append(
578 {
579 "type": "image_url",
580 "image_url": image_data,
581 },
582 )
583 else:
584 # Keep existing format (text, image_url)
585 converted_content.append(item_dict)
587 return converted_content
589 def set_message_transfer(self, message_transfer: Callable[[RunnerMessages], RunnerMessages] | None) -> None:
590 """Set or update the message transfer callback function.
592 Args:
593 message_transfer: A callback function that takes RunnerMessages as input
594 and returns RunnerMessages as output. This function will be
595 called before making API calls to allow preprocessing of messages.
596 """
597 self.message_transfer = message_transfer
599 def _add_wait_for_user_tool(self) -> None:
600 """Add wait_for_user tool for agents with completion_condition='call'.
602 This tool allows the agent to signal when it has completed its task.
603 """
605 def wait_for_user_handler() -> str:
606 """Handler for wait_for_user function."""
607 return "Waiting for user input."
609 # Add dynamic tool for task completion
610 self.fc.add_dynamic_tool(
611 name=ToolName.WAIT_FOR_USER,
612 description="Call this function when you have completed your assigned task or need more information from the user.",
613 parameters={},
614 required=[],
615 handler=wait_for_user_handler,
616 )
618 def set_stop_before_functions(self, functions: list[str] | list[Callable]) -> None:
619 """Set the list of functions that require confirmation before execution.
621 Args:
622 functions: List of function names (str) or callable objects
623 """
624 self.stop_before_functions = set()
625 for func in functions:
626 if isinstance(func, str):
627 self.stop_before_functions.add(func)
628 elif callable(func):
629 self.stop_before_functions.add(func.__name__)
630 else:
631 msg = f"stop_before_functions must contain strings or callables, got {type(func)}"
632 raise TypeError(msg)
633 logger.debug(f"Set stop_before_functions to: {self.stop_before_functions}")
635 def add_stop_before_function(self, function: str | Callable) -> None:
636 """Add a function to the stop_before_functions list.
638 Args:
639 function: Function name (str) or callable object to add
640 """
641 if isinstance(function, str):
642 function_name = function
643 elif callable(function):
644 function_name = function.__name__
645 else:
646 msg = f"function must be a string or callable, got {type(function)}"
647 raise TypeError(msg)
649 self.stop_before_functions.add(function_name)
650 logger.debug(f"Added '{function_name}' to stop_before_functions")
652 def remove_stop_before_function(self, function: str | Callable) -> None:
653 """Remove a function from the stop_before_functions list.
655 Args:
656 function: Function name (str) or callable object to remove
657 """
658 if isinstance(function, str):
659 function_name = function
660 elif callable(function):
661 function_name = function.__name__
662 else:
663 msg = f"function must be a string or callable, got {type(function)}"
664 raise TypeError(msg)
666 self.stop_before_functions.discard(function_name)
667 logger.debug(f"Removed '{function_name}' from stop_before_functions")
669 def clear_stop_before_functions(self) -> None:
670 """Clear all function names from the stop_before_functions list."""
671 self.stop_before_functions.clear()
672 logger.debug("Cleared all stop_before_functions")
674 def get_stop_before_functions(self) -> set[str]:
675 """Get the current set of function names that require confirmation.
677 Returns:
678 Set of function names
679 """
680 return self.stop_before_functions.copy()