Coverage for src/lite_agent/runner.py: 80%
432 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-25 22:58 +0900
1import json
2import warnings
3from collections.abc import AsyncGenerator, Sequence
4from datetime import datetime, timedelta, timezone
5from os import PathLike
6from pathlib import Path
7from typing import Any, Literal, cast
9from lite_agent.agent import Agent
10from lite_agent.constants import CompletionMode, StreamIncludes, ToolName
11from lite_agent.loggers import logger
12from lite_agent.types import (
13 AgentChunk,
14 AgentChunkType,
15 AssistantMessageMeta,
16 AssistantTextContent,
17 AssistantToolCall,
18 AssistantToolCallResult,
19 FlexibleInputMessage,
20 FlexibleRunnerMessage,
21 MessageUsage,
22 NewAssistantMessage,
23 NewMessage,
24 NewSystemMessage,
25 NewUserMessage,
26 ToolCall,
27 ToolCallFunction,
28 UserInput,
29 UserTextContent,
30)
31from lite_agent.types.events import AssistantMessageEvent, FunctionCallOutputEvent
32from lite_agent.utils.message_builder import MessageBuilder
35class Runner:
def __init__(self, agent: Agent, api: Literal["completion", "responses"] = "responses", *, streaming: bool = True) -> None:
    """Create a runner bound to ``agent``.

    Args:
        agent: The agent used for LLM calls and tool execution.
        api: Which agent entry point ``_run`` dispatches to ("completion" or "responses").
        streaming: Forwarded to the agent call as its ``streaming`` flag.
    """
    # Caller-supplied configuration.
    self.agent = agent
    self.api = api
    self.streaming = streaming
    # Conversation state: finalized history plus the assistant message being assembled.
    self.messages: list[FlexibleRunnerMessage] = []
    self._current_assistant_message: NewAssistantMessage | None = None
    # Token counters accumulated from every "usage" chunk this runner processes.
    initial_usage = MessageUsage(input_tokens=0, output_tokens=0, total_tokens=0)
    self.usage = initial_usage
def _start_assistant_message(self, content: str = "", meta: AssistantMessageMeta | None = None) -> None:
    """Begin a fresh in-progress assistant message, replacing any current one.

    Args:
        content: Initial text for the message's single text item.
        meta: Metadata to attach; a default ``AssistantMessageMeta`` is created when falsy.
    """
    opening_text = AssistantTextContent(text=content)
    message_meta = meta or AssistantMessageMeta()
    self._current_assistant_message = NewAssistantMessage(content=[opening_text], meta=message_meta)
51 def _ensure_current_assistant_message(self) -> NewAssistantMessage:
52 """Ensure current assistant message exists and return it."""
53 if self._current_assistant_message is None:
54 self._start_assistant_message()
55 if self._current_assistant_message is None:
56 msg = "Failed to create current assistant message"
57 raise RuntimeError(msg)
58 return self._current_assistant_message
60 def _add_to_current_assistant_message(self, content_item: AssistantTextContent | AssistantToolCall | AssistantToolCallResult) -> None:
61 """Add content to the current assistant message."""
62 self._ensure_current_assistant_message().content.append(content_item)
64 def _add_text_content_to_current_assistant_message(self, delta: str) -> None:
65 """Add text delta to the current assistant message's text content."""
66 message = self._ensure_current_assistant_message()
67 # Find the first text content item and append the delta
68 for content_item in message.content:
69 if content_item.type == "text":
70 content_item.text += delta
71 return
72 # If no text content found, add new text content
73 message.content.append(AssistantTextContent(text=delta))
75 def _finalize_assistant_message(self) -> None:
76 """Finalize the current assistant message and add it to messages."""
77 if self._current_assistant_message is not None:
78 self.messages.append(self._current_assistant_message)
79 self._current_assistant_message = None
def _add_tool_call_result(self, call_id: str, output: str, execution_time_ms: int | None = None) -> None:
    """Record a tool call result in the conversation history.

    The result is appended to the last message when that message is an assistant
    message; otherwise a new assistant message holding only the result is added.

    Args:
        call_id: Identifier of the tool call this result answers.
        output: The tool's output text.
        execution_time_ms: Optional execution time to store on the result.
    """
    tool_result = AssistantToolCallResult(call_id=call_id, output=output, execution_time_ms=execution_time_ms)

    history = self.messages
    if not history or not isinstance(history[-1], NewAssistantMessage):
        # Nothing to attach to: the result gets an assistant message of its own.
        history.append(NewAssistantMessage(content=[tool_result]))
        return

    history[-1].content.append(tool_result)
    # Everything is stored in the NewMessage format; per the original notes, the
    # conversion to the completion format happens when sending to the LLM.
102 def _normalize_includes(self, includes: Sequence[AgentChunkType] | None) -> Sequence[AgentChunkType]:
103 """Normalize includes parameter to default if None."""
104 return includes if includes is not None else StreamIncludes.DEFAULT_INCLUDES
106 def _normalize_record_path(self, record_to: PathLike | str | None) -> Path | None:
107 """Normalize record_to parameter to Path object if provided."""
108 return Path(record_to) if record_to else None
110 async def _handle_tool_calls(self, tool_calls: "Sequence[ToolCall] | None", includes: Sequence[AgentChunkType], context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]: # noqa: ANN401
111 """Handle tool calls and yield appropriate chunks."""
112 if not tool_calls:
113 return
115 # Check for transfer_to_agent calls first
116 transfer_calls = [tc for tc in tool_calls if tc.function.name == ToolName.TRANSFER_TO_AGENT]
117 if transfer_calls:
118 # Handle all transfer calls but only execute the first one
119 for i, tool_call in enumerate(transfer_calls):
120 if i == 0:
121 # Execute the first transfer
122 call_id, output = await self._handle_agent_transfer(tool_call)
123 # Generate function_call_output event if in includes
124 if "function_call_output" in includes:
125 yield FunctionCallOutputEvent(
126 tool_call_id=call_id,
127 name=tool_call.function.name,
128 content=output,
129 execution_time_ms=0, # Transfer operations are typically fast
130 )
131 else:
132 # Add response for additional transfer calls without executing them
133 output = "Transfer already executed by previous call"
134 self._add_tool_call_result(
135 call_id=tool_call.id,
136 output=output,
137 )
138 # Generate function_call_output event if in includes
139 if "function_call_output" in includes:
140 yield FunctionCallOutputEvent(
141 tool_call_id=tool_call.id,
142 name=tool_call.function.name,
143 content=output,
144 execution_time_ms=0,
145 )
146 return # Stop processing other tool calls after transfer
148 return_parent_calls = [tc for tc in tool_calls if tc.function.name == ToolName.TRANSFER_TO_PARENT]
149 if return_parent_calls:
150 # Handle multiple transfer_to_parent calls (only execute the first one)
151 for i, tool_call in enumerate(return_parent_calls):
152 if i == 0:
153 # Execute the first transfer
154 call_id, output = await self._handle_parent_transfer(tool_call)
155 # Generate function_call_output event if in includes
156 if "function_call_output" in includes:
157 yield FunctionCallOutputEvent(
158 tool_call_id=call_id,
159 name=tool_call.function.name,
160 content=output,
161 execution_time_ms=0, # Transfer operations are typically fast
162 )
163 else:
164 # Add response for additional transfer calls without executing them
165 output = "Transfer already executed by previous call"
166 self._add_tool_call_result(
167 call_id=tool_call.id,
168 output=output,
169 )
170 # Generate function_call_output event if in includes
171 if "function_call_output" in includes:
172 yield FunctionCallOutputEvent(
173 tool_call_id=tool_call.id,
174 name=tool_call.function.name,
175 content=output,
176 execution_time_ms=0,
177 )
178 return # Stop processing other tool calls after transfer
180 async for tool_call_chunk in self.agent.handle_tool_calls(tool_calls, context=context):
181 # if tool_call_chunk.type == "function_call" and tool_call_chunk.type in includes:
182 # yield tool_call_chunk
183 if tool_call_chunk.type == "function_call_output":
184 if tool_call_chunk.type in includes:
185 yield tool_call_chunk
186 # Add tool result to the last assistant message
187 if self.messages and isinstance(self.messages[-1], NewAssistantMessage):
188 tool_result = AssistantToolCallResult(
189 call_id=tool_call_chunk.tool_call_id,
190 output=tool_call_chunk.content,
191 execution_time_ms=tool_call_chunk.execution_time_ms,
192 )
193 last_message = cast("NewAssistantMessage", self.messages[-1])
194 last_message.content.append(tool_result)
196 # Note: For completion API compatibility, the conversion happens when sending to LLM
198 async def _collect_all_chunks(self, stream: AsyncGenerator[AgentChunk, None]) -> list[AgentChunk]:
199 """Collect all chunks from an async generator into a list."""
200 return [chunk async for chunk in stream]
def run(
    self,
    user_input: UserInput | None = None,
    max_steps: int = 20,
    includes: Sequence[AgentChunkType] | None = None,
    context: "Any | None" = None,  # noqa: ANN401
    record_to: PathLike | str | None = None,
    agent_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Start a run and return an async generator of chunks.

    With ``user_input`` set, pending tool calls are cancelled first, the input is
    appended to the history, and a new run is started. With ``user_input`` None the
    run continues from the current state; ``agent_kwargs`` is not forwarded on that
    path.

    Args:
        user_input: A string, a list/tuple of messages, a single message, or None.
        max_steps: Maximum number of agent steps.
        includes: Chunk types to yield; defaults are applied when None.
        context: Passed through to tool-call handling.
        record_to: Optional path handed to the agent as ``record_to_file``.
        agent_kwargs: Extra options; ``_run`` reads its "reasoning" entry.

    Returns:
        An async generator yielding the run's chunks.
    """
    logger.debug(f"Runner.run called with streaming={self.streaming}, api={self.api}")
    includes = self._normalize_includes(includes)

    if user_input is None:
        logger.debug("No user input provided, using continue logic")
        return self._run_continue_stream(max_steps, includes, self._normalize_record_path(record_to), context)

    # run() is not async, so it cannot yield the cancellation events itself;
    # they are stashed here and emitted at the start of _run.
    self._pending_cancellation_events = self._cancel_pending_tool_calls()

    if isinstance(user_input, str):
        self.messages.append(NewUserMessage(content=[UserTextContent(text=user_input)]))
    elif isinstance(user_input, (list, tuple)):
        # A sequence of messages is appended one by one.
        for message in user_input:
            self.append_message(message)
    else:
        # A single message (BaseModel, TypedDict, or dict).
        self.append_message(user_input)  # type: ignore[arg-type]

    logger.debug("Messages prepared, calling _run")
    record_path = self._normalize_record_path(record_to)
    return self._run(max_steps, includes, record_path, context=context, agent_kwargs=agent_kwargs)
async def _run(
    self,
    max_steps: int,
    includes: Sequence[AgentChunkType],
    record_to: Path | None = None,
    context: Any | None = None,  # noqa: ANN401
    agent_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[AgentChunk, None]:
    """Drive the agent step by step, yielding chunks as they arrive.

    Each step calls the agent through the configured API, folds the streamed chunks
    into the in-progress assistant message, finalizes that message, and then
    executes any pending tool calls. The loop ends when the completion condition is
    met, when ``max_steps`` is reached, or when a pending tool call requires
    confirmation.

    Args:
        max_steps: Maximum number of loop iterations.
        includes: Chunk types that are yielded to the caller.
        record_to: Passed to the agent call as ``record_to_file``.
        context: Passed through to ``_handle_tool_calls``.
        agent_kwargs: Optional options; only the "reasoning" entry is read here.

    Yields:
        Pending cancellation events first, then chunks whose type is in ``includes``.

    Raises:
        ValueError: If ``self.api`` is neither "completion" nor "responses".
    """
    logger.debug(f"Running agent with messages: {self.messages}")

    # First, yield any pending cancellation events (stashed by run()).
    # The attribute only exists between run() and this point, hence hasattr/delattr.
    if hasattr(self, "_pending_cancellation_events"):
        for cancellation_event in self._pending_cancellation_events:
            if "function_call_output" in includes:
                yield cancellation_event
        # Clear the pending events after yielding
        delattr(self, "_pending_cancellation_events")

    steps = 0
    finish_reason = None

    # Determine completion condition based on agent configuration
    # (falls back to CompletionMode.STOP when the agent has no such attribute).
    completion_condition = getattr(self.agent, "completion_condition", CompletionMode.STOP)

    def is_finish() -> bool:
        # Closure over finish_reason: it sees the value assigned at the end of each step.
        if completion_condition == CompletionMode.CALL:
            # Check if wait_for_user was called in the last assistant message
            if self.messages and isinstance(self.messages[-1], NewAssistantMessage):
                last_message = self.messages[-1]
                for content_item in last_message.content:
                    if isinstance(content_item, AssistantToolCallResult) and self._get_tool_call_name_by_id(content_item.call_id) == ToolName.WAIT_FOR_USER:
                        return True
            return False
        return finish_reason == CompletionMode.STOP

    while not is_finish() and steps < max_steps:
        logger.debug(f"Step {steps}: finish_reason={finish_reason}, is_finish()={is_finish()}")
        # Messages are kept in the new format internally and handed to the agent as-is.
        # Extract agent kwargs for reasoning configuration
        reasoning = None
        if agent_kwargs:
            reasoning = agent_kwargs.get("reasoning")

        logger.debug(f"Using API: {self.api}, streaming: {self.streaming}")
        match self.api:
            case "completion":
                logger.debug("Calling agent.completion")
                resp = await self.agent.completion(
                    self.messages,
                    record_to_file=record_to,
                    reasoning=reasoning,
                    streaming=self.streaming,
                )
            case "responses":
                logger.debug("Calling agent.responses")
                resp = await self.agent.responses(
                    self.messages,
                    record_to_file=record_to,
                    reasoning=reasoning,
                    streaming=self.streaming,
                )
            case _:
                msg = f"Unknown API type: {self.api}"
                raise ValueError(msg)
        logger.debug(f"Received response from agent: {type(resp)}")
        async for chunk in resp:
            match chunk.type:
                case "assistant_message":
                    # Start or update assistant message in new format
                    # If we already have a current assistant message, just update its metadata
                    if self._current_assistant_message is not None:
                        # Preserve all existing metadata and only update specific fields
                        original_meta = self._current_assistant_message.meta
                        original_meta.sent_at = chunk.message.meta.sent_at
                        if hasattr(chunk.message.meta, "latency_ms"):
                            original_meta.latency_ms = chunk.message.meta.latency_ms
                        if hasattr(chunk.message.meta, "output_time_ms"):
                            # The chunk's output_time_ms is stored as total_time_ms on the message.
                            original_meta.total_time_ms = chunk.message.meta.output_time_ms
                        # Copy these fields from the chunk's metadata when it has them.
                        for attr in ["model", "usage", "input_tokens", "output_tokens"]:
                            if hasattr(chunk.message.meta, attr):
                                setattr(original_meta, attr, getattr(chunk.message.meta, attr))
                    else:
                        # No message was accumulated from deltas (the original comment
                        # attributes this to non-streaming mode): adopt the chunk's message.
                        self._current_assistant_message = chunk.message

                    # If model is None, try to get it from agent client
                    if self._current_assistant_message is not None and self._current_assistant_message.meta.model is None and hasattr(self.agent.client, "model"):
                        self._current_assistant_message.meta.model = self.agent.client.model
                    # Only yield assistant_message chunk if it's in includes and a message exists
                    if chunk.type in includes and self._current_assistant_message is not None:
                        # Emit a new event wrapping the current assistant message, not the raw chunk
                        updated_chunk = AssistantMessageEvent(
                            message=self._current_assistant_message,
                        )
                        yield updated_chunk
                case "content_delta":
                    # Accumulate text content to current assistant message
                    self._add_text_content_to_current_assistant_message(chunk.delta)
                    # Yield the content_delta chunk only if it's in includes
                    if chunk.type in includes:
                        yield chunk
                case "function_call":
                    # Add tool call to current assistant message
                    # Keep arguments as string for compatibility with funcall library
                    tool_call = AssistantToolCall(
                        call_id=chunk.call_id,
                        name=chunk.name,
                        arguments=chunk.arguments or "{}",
                    )
                    self._add_to_current_assistant_message(tool_call)
                    # Yield the function_call chunk only if it's in includes
                    if chunk.type in includes:
                        yield chunk
                case "usage":
                    # Update the current or last assistant message with usage data and output_time_ms
                    usage_time = datetime.now(timezone.utc)

                    # Always accumulate usage in runner first
                    # (total_tokens is recomputed as input + output rather than read from the chunk).
                    self.usage.input_tokens = (self.usage.input_tokens or 0) + (chunk.usage.input_tokens or 0)
                    self.usage.output_tokens = (self.usage.output_tokens or 0) + (chunk.usage.output_tokens or 0)
                    self.usage.total_tokens = (self.usage.total_tokens or 0) + (chunk.usage.input_tokens or 0) + (chunk.usage.output_tokens or 0)

                    # Try to find the assistant message to update
                    target_message = None

                    # First check if we have a current assistant message
                    if self._current_assistant_message is not None:
                        target_message = self._current_assistant_message
                    else:
                        # Otherwise, look for the last assistant message in the list
                        for i in range(len(self.messages) - 1, -1, -1):
                            current_message = self.messages[i]
                            if isinstance(current_message, NewAssistantMessage):
                                target_message = current_message
                                break

                    # Update the target message with usage information
                    if target_message is not None:
                        if target_message.meta.usage is None:
                            target_message.meta.usage = MessageUsage()
                        target_message.meta.usage.input_tokens = chunk.usage.input_tokens
                        target_message.meta.usage.output_tokens = chunk.usage.output_tokens
                        target_message.meta.usage.total_tokens = (chunk.usage.input_tokens or 0) + (chunk.usage.output_tokens or 0)

                        # Calculate output_time_ms if latency_ms is available
                        if target_message.meta.latency_ms is not None:
                            # Computes usage_time - (sent_at - latency_ms), clamped at 0.
                            # Per the original note, sent_at is when the message was completed,
                            # so sent_at - latency_ms approximates the time of first output.
                            first_output_time_approx = target_message.meta.sent_at - timedelta(milliseconds=target_message.meta.latency_ms)
                            output_time_ms = int((usage_time - first_output_time_approx).total_seconds() * 1000)
                            target_message.meta.total_time_ms = max(0, output_time_ms)
                    # Yield the usage chunk only if it's in includes
                    if chunk.type in includes:
                        yield chunk
                case _ if chunk.type in includes:
                    # Any other chunk type is passed through when requested.
                    yield chunk

        # Finalize assistant message so it can be found in pending function calls
        self._finalize_assistant_message()

        # Check for pending tool calls after processing current assistant message
        pending_tool_calls = self._find_pending_tool_calls()
        logger.debug(f"Found {len(pending_tool_calls)} pending tool calls")
        if pending_tool_calls:
            # Convert to ToolCall format for existing handler
            tool_calls = self._convert_tool_calls_to_tool_calls(pending_tool_calls)
            require_confirm_tools = await self.agent.list_require_confirm_tools(tool_calls)
            if require_confirm_tools:
                # Stop here without executing anything; the calls stay pending.
                return
            async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                yield tool_chunk
            finish_reason = "tool_calls"
        else:
            finish_reason = CompletionMode.STOP
        steps += 1
async def has_require_confirm_tools(self):
    """Report whether any pending tool call requires confirmation.

    Returns:
        False when the last assistant message has no pending tool calls; otherwise
        the truthiness of ``self.agent.list_require_confirm_tools`` for those calls.
    """
    pending = self._find_pending_tool_calls()
    if pending:
        converted = self._convert_tool_calls_to_tool_calls(pending)
        return bool(await self.agent.list_require_confirm_tools(converted))
    return False
async def run_continue_until_complete(
    self,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
) -> list[AgentChunk]:
    """Deprecated: Use run_until_complete(None) instead.

    Emits a ``DeprecationWarning``, then continues the run through
    ``run_continue_stream`` and returns all chunks as a list.
    """
    deprecation_notice = "run_continue_until_complete is deprecated. Use run_until_complete(None) instead."
    warnings.warn(deprecation_notice, DeprecationWarning, stacklevel=2)
    stream = self.run_continue_stream(max_steps, includes, record_to=record_to)
    chunks = await self._collect_all_chunks(stream)
    return chunks
def run_continue_stream(
    self,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    context: "Any | None" = None,  # noqa: ANN401
) -> AsyncGenerator[AgentChunk, None]:
    """Deprecated: Use run(None) instead.

    Emits a ``DeprecationWarning`` and returns the generator produced by
    ``_run_continue_stream`` for the same arguments.
    """
    deprecation_notice = "run_continue_stream is deprecated. Use run(None) instead."
    warnings.warn(deprecation_notice, DeprecationWarning, stacklevel=2)
    return self._run_continue_stream(max_steps, includes, record_to=record_to, context=context)
465 async def _run_continue_stream(
466 self,
467 max_steps: int = 20,
468 includes: Sequence[AgentChunkType] | None = None,
469 record_to: PathLike | str | None = None,
470 context: "Any | None" = None, # noqa: ANN401
471 ) -> AsyncGenerator[AgentChunk, None]:
472 """Continue running the agent and return a RunResponse object that can be asynchronously iterated for each chunk."""
473 includes = self._normalize_includes(includes)
475 # Find pending tool calls in responses format
476 pending_tool_calls = self._find_pending_tool_calls()
477 if pending_tool_calls:
478 # Convert to ToolCall format for existing handler
479 tool_calls = self._convert_tool_calls_to_tool_calls(pending_tool_calls)
480 async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
481 yield tool_chunk
482 async for chunk in self._run(max_steps, includes, self._normalize_record_path(record_to)):
483 if chunk.type in includes:
484 yield chunk
485 else:
486 # Check if there are any messages and what the last message is
487 if not self.messages:
488 msg = "Cannot continue running without a valid last message from the assistant."
489 raise ValueError(msg)
491 resp = self._run(max_steps=max_steps, includes=includes, record_to=self._normalize_record_path(record_to), context=context)
492 async for chunk in resp:
493 yield chunk
async def run_until_complete(
    self,
    user_input: UserInput | None = None,
    max_steps: int = 20,
    includes: list[AgentChunkType] | None = None,
    record_to: PathLike | str | None = None,
    *,
    context: "Any | None" = None,  # noqa: ANN401
    agent_kwargs: dict[str, Any] | None = None,
) -> list[AgentChunk]:
    """Run the agent until it completes and return every chunk it produced.

    This is ``run`` with the resulting stream collected into a list.

    Args:
        user_input: Input forwarded to ``run``; None continues from the current state.
        max_steps: Maximum number of agent steps.
        includes: Chunk types to collect; defaults are applied when None.
        record_to: Optional recording path forwarded to ``run``.
        context: Forwarded to ``run`` (keyword-only).
        agent_kwargs: Forwarded to ``run`` (keyword-only).

    Returns:
        All chunks yielded by the run, in order.
    """
    resp = self.run(
        user_input,
        max_steps,
        includes,
        context=context,
        record_to=record_to,
        agent_kwargs=agent_kwargs,
    )
    return await self._collect_all_chunks(resp)
506 def _analyze_last_assistant_message(self) -> tuple[list[AssistantToolCall], dict[str, str]]:
507 """Analyze the last assistant message and return pending tool calls and tool call map."""
508 if not self.messages or not isinstance(self.messages[-1], NewAssistantMessage):
509 return [], {}
511 tool_calls = {}
512 tool_results = set()
513 tool_call_names = {}
515 last_message = self.messages[-1]
516 for content_item in last_message.content:
517 if isinstance(content_item, AssistantToolCall):
518 tool_calls[content_item.call_id] = content_item
519 tool_call_names[content_item.call_id] = content_item.name
520 elif isinstance(content_item, AssistantToolCallResult):
521 tool_results.add(content_item.call_id)
523 # Return pending tool calls and tool call names map
524 pending_calls = [call for call_id, call in tool_calls.items() if call_id not in tool_results]
525 return pending_calls, tool_call_names
527 def _find_pending_tool_calls(self) -> list[AssistantToolCall]:
528 """Find tool calls that don't have corresponding results yet."""
529 pending_calls, _ = self._analyze_last_assistant_message()
530 return pending_calls
532 def _get_tool_call_name_by_id(self, call_id: str) -> str | None:
533 """Get the tool name for a given call_id from the last assistant message."""
534 _, tool_call_names = self._analyze_last_assistant_message()
535 return tool_call_names.get(call_id)
537 def _cancel_pending_tool_calls(self) -> list[FunctionCallOutputEvent]:
538 """Cancel all pending tool calls by adding cancellation results.
540 Returns:
541 List of FunctionCallOutputEvent for each cancelled tool call
542 """
543 pending_tool_calls = self._find_pending_tool_calls()
544 if not pending_tool_calls:
545 return []
547 logger.debug(f"Cancelling {len(pending_tool_calls)} pending tool calls due to new user input")
549 cancellation_events = []
550 for tool_call in pending_tool_calls:
551 output = "Operation cancelled by user - new input provided"
552 self._add_tool_call_result(
553 call_id=tool_call.call_id,
554 output=output,
555 execution_time_ms=0,
556 )
558 # Create cancellation event
559 cancellation_event = FunctionCallOutputEvent(
560 tool_call_id=tool_call.call_id,
561 name=tool_call.name,
562 content=output,
563 execution_time_ms=0,
564 )
565 cancellation_events.append(cancellation_event)
567 return cancellation_events
569 def _convert_tool_calls_to_tool_calls(self, tool_calls: list[AssistantToolCall]) -> list[ToolCall]:
570 """Convert AssistantToolCall objects to ToolCall objects for compatibility."""
571 return [
572 ToolCall(
573 id=tc.call_id,
574 type="function",
575 function=ToolCallFunction(
576 name=tc.name,
577 arguments=tc.arguments if isinstance(tc.arguments, str) else str(tc.arguments),
578 ),
579 index=i,
580 )
581 for i, tc in enumerate(tool_calls)
582 ]
def set_chat_history(self, messages: Sequence[FlexibleInputMessage], root_agent: Agent | None = None) -> None:
    """Replace the chat history and select the active agent from recorded transfers.

    Each message is appended through ``append_message``; after every appended
    message the transfer tool calls it contains are replayed to work out which
    agent should be active at the end of the history.

    Args:
        messages: Messages to set as the chat history. Passing ``self.messages``
            itself is safe: the input is copied before the history is cleared.
        root_agent: Agent to start tracking from. If None, uses ``self.agent``.
    """
    # Snapshot first: clearing self.messages in place would otherwise also empty
    # the input when the caller passes the runner's own message list.
    incoming = list(messages)
    self.messages.clear()

    current_agent = root_agent if root_agent is not None else self.agent

    for input_message in incoming:
        length_before = len(self.messages)
        self.append_message(input_message)

        # Track transfers on the converted message that was just stored.
        if len(self.messages) > length_before:
            current_agent = self._track_agent_transfer_in_message(self.messages[-1], current_agent)

    self.agent = current_agent
    logger.info(f"Chat history set with {len(self.messages)} messages. Current agent: {self.agent.name}")
def get_messages(self) -> list[NewMessage]:
    """Return the stored messages that are ``NewMessage`` instances.

    Entries of any other kind (for example plain dicts) are left out.
    """
    typed_messages = []
    for message in self.messages:
        if isinstance(message, NewMessage):
            typed_messages.append(message)
    return typed_messages
def get_dict_messages(self) -> list[dict[str, Any]]:
    """Return the stored messages as plain dictionaries.

    Objects exposing ``model_dump`` are dumped with ``mode="json"``, dicts are
    returned as they are, and anything else is passed through ``dict()``.
    """

    def as_dict(message):
        if hasattr(message, "model_dump"):
            return message.model_dump(mode="json")
        if isinstance(message, dict):
            return message
        # Fallback for any other message types
        return dict(message)

    return [as_dict(message) for message in self.messages]
def add_user_message(self, text: str) -> None:
    """Append a user message consisting of a single text item."""
    text_part = UserTextContent(text=text)
    self.append_message(NewUserMessage(content=[text_part]))
def add_assistant_message(self, text: str) -> None:
    """Append an assistant message consisting of a single text item."""
    text_part = AssistantTextContent(text=text)
    self.append_message(NewAssistantMessage(content=[text_part]))
def add_system_message(self, content: str) -> None:
    """Append a system message with the given content."""
    system_message = NewSystemMessage(content=content)
    self.append_message(system_message)
def _track_agent_transfer_in_message(self, message: FlexibleRunnerMessage, current_agent: Agent) -> Agent:
    """Work out which agent is active after ``message``.

    Args:
        message: The message to inspect for transfer tool calls.
        current_agent: The agent active before this message.

    Returns:
        ``current_agent`` for anything but an assistant message; otherwise the
        result of tracking the transfers inside that assistant message.
    """
    if not isinstance(message, NewAssistantMessage):
        return current_agent
    return self._track_transfer_from_new_assistant_message(message, current_agent)
665 def _track_transfer_from_new_assistant_message(self, message: NewAssistantMessage, current_agent: Agent) -> Agent:
666 """Track transfers from NewAssistantMessage objects."""
667 for content_item in message.content:
668 if content_item.type == "tool_call":
669 if content_item.name == ToolName.TRANSFER_TO_AGENT:
670 arguments = content_item.arguments if isinstance(content_item.arguments, str) else str(content_item.arguments)
671 return self._handle_transfer_to_agent_tracking(arguments, current_agent)
672 if content_item.name == ToolName.TRANSFER_TO_PARENT:
673 return self._handle_transfer_to_parent_tracking(current_agent)
674 return current_agent
676 def _handle_transfer_to_agent_tracking(self, arguments: str | dict, current_agent: Agent) -> Agent:
677 """Handle transfer_to_agent function call tracking."""
678 try:
679 args_dict = json.loads(arguments) if isinstance(arguments, str) else arguments
681 target_agent_name = args_dict.get("name")
682 if target_agent_name:
683 target_agent = self._find_agent_by_name(current_agent, target_agent_name)
684 if target_agent:
685 logger.debug(f"History tracking: Transferring from {current_agent.name} to {target_agent_name}")
686 return target_agent
688 logger.warning(f"Target agent '{target_agent_name}' not found in handoffs during history setup")
689 except (json.JSONDecodeError, KeyError, TypeError) as e:
690 logger.warning(f"Failed to parse transfer_to_agent arguments during history setup: {e}")
692 return current_agent
def _handle_transfer_to_parent_tracking(self, current_agent: Agent) -> Agent:
    """Resolve a recorded ``transfer_to_parent`` call.

    Returns:
        The parent of ``current_agent`` when it has one; otherwise ``current_agent``
        itself, after logging a warning.
    """
    parent = current_agent.parent
    if not parent:
        logger.warning(f"Agent {current_agent.name} has no parent to transfer back to during history setup")
        return current_agent

    logger.debug(f"History tracking: Transferring from {current_agent.name} back to parent {parent.name}")
    return parent
703 def _find_agent_by_name(self, root_agent: Agent, target_name: str) -> Agent | None:
704 """Find an agent by name in the handoffs tree starting from root_agent.
706 Args:
707 root_agent: The root agent to start searching from
708 target_name: The name of the agent to find
710 Returns:
711 The agent if found, None otherwise
712 """
713 # Check direct handoffs from current agent
714 if root_agent.handoffs:
715 for agent in root_agent.handoffs:
716 if agent.name == target_name:
717 return agent
719 # If not found in direct handoffs, check if we need to look in parent's handoffs
720 # This handles cases where agents can transfer to siblings
721 current = root_agent
722 while current.parent is not None:
723 current = current.parent
724 if current.handoffs:
725 for agent in current.handoffs:
726 if agent.name == target_name:
727 return agent
729 return None
731 def append_message(self, message: FlexibleInputMessage) -> None:
732 """Append a message to the conversation history.
734 Accepts both NewMessage format and dict format (which will be converted internally).
735 """
736 if isinstance(message, NewMessage):
737 self.messages.append(message)
738 elif isinstance(message, dict):
739 # Convert dict to NewMessage using MessageBuilder
740 role = message.get("role", "").lower()
741 if role == "user":
742 converted_message = MessageBuilder.build_user_message_from_dict(message)
743 elif role == "assistant":
744 converted_message = MessageBuilder.build_assistant_message_from_dict(message)
745 elif role == "system":
746 converted_message = MessageBuilder.build_system_message_from_dict(message)
747 else:
748 msg = f"Unsupported message role: {role}. Must be 'user', 'assistant', or 'system'."
749 raise ValueError(msg)
751 self.messages.append(converted_message)
752 else:
753 msg = f"Unsupported message type: {type(message)}. Supports NewMessage types and dict."
754 raise TypeError(msg)
async def _handle_agent_transfer(self, tool_call: ToolCall) -> tuple[str, str]:
    """Handle a ``transfer_to_agent`` tool call.

    The target is looked up by name in the current agent's handoffs. On success the
    transfer tool is executed, its result is recorded, and ``self.agent`` is
    switched to the target. Every failure is recorded as the tool call's result
    instead of being raised.

    Args:
        tool_call: The transfer_to_agent tool call.

    Returns:
        Tuple of (call_id, output) for the tool call result.
    """

    def fail(output: str) -> tuple[str, str]:
        # Record the error as the tool result so the call does not stay pending.
        self._add_tool_call_result(call_id=tool_call.id, output=output)
        return tool_call.id, output

    # Parse the arguments to get the target agent name.
    try:
        arguments = json.loads(tool_call.function.arguments or "{}")
    except (json.JSONDecodeError, TypeError):
        arguments = None
    # Valid JSON that is not an object (list, string, ...) has no "name" entry
    # and is treated as a parse failure as well.
    if not isinstance(arguments, dict):
        logger.error("Failed to parse transfer_to_agent arguments: %s", tool_call.function.arguments)
        return fail("Failed to parse transfer arguments")

    target_agent_name = arguments.get("name")
    if not target_agent_name:
        logger.error("No target agent name provided in transfer_to_agent call")
        return fail("No target agent name provided")

    if not self.agent.handoffs:
        logger.error("Current agent has no handoffs configured")
        return fail("Current agent has no handoffs configured")

    target_agent = next((agent for agent in self.agent.handoffs if agent.name == target_agent_name), None)
    if not target_agent:
        logger.error("Target agent '%s' not found in handoffs", target_agent_name)
        return fail(f"Target agent '{target_agent_name}' not found in handoffs")

    # Execute the transfer tool call to get the result.
    try:
        result = await self.agent.fc.call_function_async(
            tool_call.function.name,
            tool_call.function.arguments or "",
        )

        output = str(result)
        self._add_tool_call_result(
            call_id=tool_call.id,
            output=output,
        )

        # Switch to the target agent only after the result has been recorded.
        logger.info("Transferring conversation from %s to %s", self.agent.name, target_agent_name)
        self.agent = target_agent

    except Exception as e:
        logger.exception("Failed to execute transfer_to_agent tool call")
        return fail(f"Transfer failed: {e!s}")
    else:
        return tool_call.id, output
async def _handle_parent_transfer(self, tool_call: ToolCall) -> tuple[str, str]:
    """Handle a ``transfer_to_parent`` tool call.

    When the current agent has a parent, the transfer tool is executed, its result
    is recorded, and ``self.agent`` becomes the parent. A missing parent or a
    failing tool execution is recorded as the tool call's result instead of being
    raised.

    Args:
        tool_call: The transfer_to_parent tool call.

    Returns:
        Tuple of (call_id, output) for the tool call result.
    """
    call_id = tool_call.id

    if not self.agent.parent:
        logger.error("Current agent has no parent to transfer back to.")
        output = "Current agent has no parent to transfer back to"
        self._add_tool_call_result(call_id=call_id, output=output)
        return call_id, output

    try:
        result = await self.agent.fc.call_function_async(
            tool_call.function.name,
            tool_call.function.arguments or "",
        )
        output = str(result)
        self._add_tool_call_result(call_id=call_id, output=output)

        # Switch to the parent only after the result has been recorded.
        logger.info("Transferring conversation from %s back to parent %s", self.agent.name, self.agent.parent.name)
        self.agent = self.agent.parent
    except Exception as e:
        logger.exception("Failed to execute transfer_to_parent tool call")
        output = f"Transfer to parent failed: {e!s}"
        # Record the failure as the tool result.
        self._add_tool_call_result(call_id=call_id, output=output)

    return call_id, output