Coverage for src/lite_agent/runner.py: 80%

432 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-25 22:58 +0900

1import json 

2import warnings 

3from collections.abc import AsyncGenerator, Sequence 

4from datetime import datetime, timedelta, timezone 

5from os import PathLike 

6from pathlib import Path 

7from typing import Any, Literal, cast 

8 

9from lite_agent.agent import Agent 

10from lite_agent.constants import CompletionMode, StreamIncludes, ToolName 

11from lite_agent.loggers import logger 

12from lite_agent.types import ( 

13 AgentChunk, 

14 AgentChunkType, 

15 AssistantMessageMeta, 

16 AssistantTextContent, 

17 AssistantToolCall, 

18 AssistantToolCallResult, 

19 FlexibleInputMessage, 

20 FlexibleRunnerMessage, 

21 MessageUsage, 

22 NewAssistantMessage, 

23 NewMessage, 

24 NewSystemMessage, 

25 NewUserMessage, 

26 ToolCall, 

27 ToolCallFunction, 

28 UserInput, 

29 UserTextContent, 

30) 

31from lite_agent.types.events import AssistantMessageEvent, FunctionCallOutputEvent 

32from lite_agent.utils.message_builder import MessageBuilder 

33 

34 

35class Runner: 

    def __init__(self, agent: Agent, api: Literal["completion", "responses"] = "responses", *, streaming: bool = True) -> None:
        """Coordinate multi-step agent execution and conversation state.

        Args:
            agent: The (root) agent that serves requests.
            api: Which LLM API surface to call ("completion" or "responses").
            streaming: Whether LLM responses are requested in streaming mode.
        """
        self.agent = agent
        # Full conversation history; may hold several message representations.
        self.messages: list[FlexibleRunnerMessage] = []
        self.api = api
        self.streaming = streaming
        # Assistant message currently being assembled from stream chunks.
        self._current_assistant_message: NewAssistantMessage | None = None
        # Token usage accumulated across all LLM calls made by this runner.
        self.usage = MessageUsage(input_tokens=0, output_tokens=0, total_tokens=0)

43 

44 def _start_assistant_message(self, content: str = "", meta: AssistantMessageMeta | None = None) -> None: 

45 """Start a new assistant message.""" 

46 self._current_assistant_message = NewAssistantMessage( 

47 content=[AssistantTextContent(text=content)], 

48 meta=meta or AssistantMessageMeta(), 

49 ) 

50 

51 def _ensure_current_assistant_message(self) -> NewAssistantMessage: 

52 """Ensure current assistant message exists and return it.""" 

53 if self._current_assistant_message is None: 

54 self._start_assistant_message() 

55 if self._current_assistant_message is None: 

56 msg = "Failed to create current assistant message" 

57 raise RuntimeError(msg) 

58 return self._current_assistant_message 

59 

    def _add_to_current_assistant_message(self, content_item: AssistantTextContent | AssistantToolCall | AssistantToolCallResult) -> None:
        """Add content to the current assistant message."""
        # Lazily creates the assistant message if streaming has not started one yet.
        self._ensure_current_assistant_message().content.append(content_item)

63 

64 def _add_text_content_to_current_assistant_message(self, delta: str) -> None: 

65 """Add text delta to the current assistant message's text content.""" 

66 message = self._ensure_current_assistant_message() 

67 # Find the first text content item and append the delta 

68 for content_item in message.content: 

69 if content_item.type == "text": 

70 content_item.text += delta 

71 return 

72 # If no text content found, add new text content 

73 message.content.append(AssistantTextContent(text=delta)) 

74 

75 def _finalize_assistant_message(self) -> None: 

76 """Finalize the current assistant message and add it to messages.""" 

77 if self._current_assistant_message is not None: 

78 self.messages.append(self._current_assistant_message) 

79 self._current_assistant_message = None 

80 

81 def _add_tool_call_result(self, call_id: str, output: str, execution_time_ms: int | None = None) -> None: 

82 """Add a tool call result to the last assistant message, or create a new one if needed.""" 

83 result = AssistantToolCallResult( 

84 call_id=call_id, 

85 output=output, 

86 execution_time_ms=execution_time_ms, 

87 ) 

88 

89 if self.messages and isinstance(self.messages[-1], NewAssistantMessage): 

90 # Add to existing assistant message 

91 last_message = cast("NewAssistantMessage", self.messages[-1]) 

92 last_message.content.append(result) 

93 else: 

94 # Create new assistant message with just the tool result 

95 assistant_message = NewAssistantMessage(content=[result]) 

96 self.messages.append(assistant_message) 

97 

98 # For completion API compatibility, create a separate assistant message 

99 # Note: In the new architecture, we store everything as NewMessage format 

100 # The conversion to completion format happens when sending to LLM 

101 

    def _normalize_includes(self, includes: Sequence[AgentChunkType] | None) -> Sequence[AgentChunkType]:
        """Normalize includes parameter to default if None."""
        # None means "caller did not specify"; fall back to the library default set.
        return includes if includes is not None else StreamIncludes.DEFAULT_INCLUDES

105 

106 def _normalize_record_path(self, record_to: PathLike | str | None) -> Path | None: 

107 """Normalize record_to parameter to Path object if provided.""" 

108 return Path(record_to) if record_to else None 

109 

    async def _handle_tool_calls(self, tool_calls: "Sequence[ToolCall] | None", includes: Sequence[AgentChunkType], context: "Any | None" = None) -> AsyncGenerator[AgentChunk, None]:  # noqa: ANN401
        """Execute tool calls and yield the resulting chunks.

        Transfer tools (transfer_to_agent / transfer_to_parent) get special
        treatment: only the first transfer call is executed, duplicate
        transfers receive a placeholder result, and no other tool calls run
        after a transfer. Ordinary tools are delegated to the agent.

        Args:
            tool_calls: Tool calls requested by the assistant, if any.
            includes: Chunk types the caller wants yielded.
            context: Opaque value forwarded to tool execution.
        """
        if not tool_calls:
            return

        # Check for transfer_to_agent calls first
        transfer_calls = [tc for tc in tool_calls if tc.function.name == ToolName.TRANSFER_TO_AGENT]
        if transfer_calls:
            # Handle all transfer calls but only execute the first one
            for i, tool_call in enumerate(transfer_calls):
                if i == 0:
                    # Execute the first transfer
                    call_id, output = await self._handle_agent_transfer(tool_call)
                    # Generate function_call_output event if in includes
                    if "function_call_output" in includes:
                        yield FunctionCallOutputEvent(
                            tool_call_id=call_id,
                            name=tool_call.function.name,
                            content=output,
                            execution_time_ms=0,  # Transfer operations are typically fast
                        )
                else:
                    # Add response for additional transfer calls without executing them
                    output = "Transfer already executed by previous call"
                    self._add_tool_call_result(
                        call_id=tool_call.id,
                        output=output,
                    )
                    # Generate function_call_output event if in includes
                    if "function_call_output" in includes:
                        yield FunctionCallOutputEvent(
                            tool_call_id=tool_call.id,
                            name=tool_call.function.name,
                            content=output,
                            execution_time_ms=0,
                        )
            return  # Stop processing other tool calls after transfer

        return_parent_calls = [tc for tc in tool_calls if tc.function.name == ToolName.TRANSFER_TO_PARENT]
        if return_parent_calls:
            # Handle multiple transfer_to_parent calls (only execute the first one)
            for i, tool_call in enumerate(return_parent_calls):
                if i == 0:
                    # Execute the first transfer
                    call_id, output = await self._handle_parent_transfer(tool_call)
                    # Generate function_call_output event if in includes
                    if "function_call_output" in includes:
                        yield FunctionCallOutputEvent(
                            tool_call_id=call_id,
                            name=tool_call.function.name,
                            content=output,
                            execution_time_ms=0,  # Transfer operations are typically fast
                        )
                else:
                    # Add response for additional transfer calls without executing them
                    output = "Transfer already executed by previous call"
                    self._add_tool_call_result(
                        call_id=tool_call.id,
                        output=output,
                    )
                    # Generate function_call_output event if in includes
                    if "function_call_output" in includes:
                        yield FunctionCallOutputEvent(
                            tool_call_id=tool_call.id,
                            name=tool_call.function.name,
                            content=output,
                            execution_time_ms=0,
                        )
            return  # Stop processing other tool calls after transfer

        # Ordinary tools: the agent executes them and streams back events.
        async for tool_call_chunk in self.agent.handle_tool_calls(tool_calls, context=context):
            if tool_call_chunk.type == "function_call_output":
                if tool_call_chunk.type in includes:
                    yield tool_call_chunk
                # Add tool result to the last assistant message
                if self.messages and isinstance(self.messages[-1], NewAssistantMessage):
                    tool_result = AssistantToolCallResult(
                        call_id=tool_call_chunk.tool_call_id,
                        output=tool_call_chunk.content,
                        execution_time_ms=tool_call_chunk.execution_time_ms,
                    )
                    last_message = cast("NewAssistantMessage", self.messages[-1])
                    last_message.content.append(tool_result)

        # Note: For completion API compatibility, the conversion happens when sending to LLM

197 

    async def _collect_all_chunks(self, stream: AsyncGenerator[AgentChunk, None]) -> list[AgentChunk]:
        """Collect all chunks from an async generator into a list."""
        # Drains the stream completely; used by the *_until_complete helpers.
        return [chunk async for chunk in stream]

201 

    def run(
        self,
        user_input: UserInput | None = None,
        max_steps: int = 20,
        includes: Sequence[AgentChunkType] | None = None,
        context: "Any | None" = None,  # noqa: ANN401
        record_to: PathLike | str | None = None,
        agent_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Run the agent and return a RunResponse object that can be asynchronously iterated for each chunk.

        If user_input is None, the method will continue execution from the current state,
        equivalent to calling the continue methods.

        Args:
            user_input: A plain string, a single message, or a sequence of
                messages to append before running; None continues from the
                current state.
            max_steps: Upper bound on agent/tool iteration steps.
            includes: Chunk types to yield; None selects the default set.
            context: Opaque value forwarded to tool execution.
            record_to: Optional file path to record the LLM exchange to.
            agent_kwargs: Extra agent options (e.g. "reasoning") forwarded to the LLM call.
        """
        logger.debug(f"Runner.run called with streaming={self.streaming}, api={self.api}")
        includes = self._normalize_includes(includes)

        # If no user input provided, use continue logic
        if user_input is None:
            logger.debug("No user input provided, using continue logic")
            # NOTE(review): agent_kwargs is not forwarded on this path — confirm intended.
            return self._run_continue_stream(max_steps, includes, self._normalize_record_path(record_to), context)

        # Cancel any pending tool calls before processing new user input
        # and yield cancellation events if they should be included
        cancellation_events = self._cancel_pending_tool_calls()

        # run() itself is synchronous, so the events cannot be yielded here;
        # stash them on the instance for _run to emit first.
        self._pending_cancellation_events = cancellation_events

        # Process user input
        match user_input:
            case str():
                # A bare string becomes a single user text message.
                self.messages.append(NewUserMessage(content=[UserTextContent(text=user_input)]))
            case list() | tuple():
                # Handle sequence of messages
                for message in user_input:
                    self.append_message(message)
            case _:
                # Handle single message (BaseModel, TypedDict, or dict)
                self.append_message(user_input)  # type: ignore[arg-type]
        logger.debug("Messages prepared, calling _run")
        return self._run(max_steps, includes, self._normalize_record_path(record_to), context=context, agent_kwargs=agent_kwargs)

245 

    async def _run(
        self,
        max_steps: int,
        includes: Sequence[AgentChunkType],
        record_to: Path | None = None,
        context: Any | None = None,  # noqa: ANN401
        agent_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Run the agent and return a RunResponse object that can be asynchronously iterated for each chunk.

        Core loop: call the configured LLM API, fold streamed chunks into the
        current assistant message, then execute any pending tool calls, and
        repeat until the completion condition holds or max_steps is reached.
        """
        logger.debug(f"Running agent with messages: {self.messages}")

        # First, yield any pending cancellation events
        # (stashed on the instance by run() because run() is not async).
        if hasattr(self, "_pending_cancellation_events"):
            for cancellation_event in self._pending_cancellation_events:
                if "function_call_output" in includes:
                    yield cancellation_event
            # Clear the pending events after yielding
            delattr(self, "_pending_cancellation_events")

        steps = 0
        finish_reason = None

        # Determine completion condition based on agent configuration
        completion_condition = getattr(self.agent, "completion_condition", CompletionMode.STOP)

        def is_finish() -> bool:
            # CALL mode: only finished once wait_for_user has produced a
            # result in the latest assistant message; otherwise finish when
            # the LLM stopped without requesting tools.
            if completion_condition == CompletionMode.CALL:
                # Check if wait_for_user was called in the last assistant message
                if self.messages and isinstance(self.messages[-1], NewAssistantMessage):
                    last_message = self.messages[-1]
                    for content_item in last_message.content:
                        if isinstance(content_item, AssistantToolCallResult) and self._get_tool_call_name_by_id(content_item.call_id) == ToolName.WAIT_FOR_USER:
                            return True
                return False
            return finish_reason == CompletionMode.STOP

        while not is_finish() and steps < max_steps:
            logger.debug(f"Step {steps}: finish_reason={finish_reason}, is_finish()={is_finish()}")
            # Convert to legacy format only when needed for LLM communication
            # This allows us to keep the new format internally but ensures compatibility
            # Extract agent kwargs for reasoning configuration
            reasoning = None
            if agent_kwargs:
                reasoning = agent_kwargs.get("reasoning")

            logger.debug(f"Using API: {self.api}, streaming: {self.streaming}")
            match self.api:
                case "completion":
                    logger.debug("Calling agent.completion")
                    resp = await self.agent.completion(
                        self.messages,
                        record_to_file=record_to,
                        reasoning=reasoning,
                        streaming=self.streaming,
                    )
                case "responses":
                    logger.debug("Calling agent.responses")
                    resp = await self.agent.responses(
                        self.messages,
                        record_to_file=record_to,
                        reasoning=reasoning,
                        streaming=self.streaming,
                    )
                case _:
                    msg = f"Unknown API type: {self.api}"
                    raise ValueError(msg)
            logger.debug(f"Received response from agent: {type(resp)}")
            async for chunk in resp:
                match chunk.type:
                    case "assistant_message":
                        # Start or update assistant message in new format
                        # If we already have a current assistant message, just update its metadata
                        if self._current_assistant_message is not None:
                            # Preserve all existing metadata and only update specific fields
                            original_meta = self._current_assistant_message.meta
                            original_meta.sent_at = chunk.message.meta.sent_at
                            if hasattr(chunk.message.meta, "latency_ms"):
                                original_meta.latency_ms = chunk.message.meta.latency_ms
                            if hasattr(chunk.message.meta, "output_time_ms"):
                                original_meta.total_time_ms = chunk.message.meta.output_time_ms
                            # Preserve other metadata fields like model, usage, etc.
                            for attr in ["model", "usage", "input_tokens", "output_tokens"]:
                                if hasattr(chunk.message.meta, attr):
                                    setattr(original_meta, attr, getattr(chunk.message.meta, attr))
                        else:
                            # For non-streaming mode, directly use the complete message from the response handler
                            self._current_assistant_message = chunk.message

                        # If model is None, try to get it from agent client
                        if self._current_assistant_message is not None and self._current_assistant_message.meta.model is None and hasattr(self.agent.client, "model"):
                            self._current_assistant_message.meta.model = self.agent.client.model
                        # Only yield assistant_message chunk if it's in includes and has content
                        if chunk.type in includes and self._current_assistant_message is not None:
                            # Create a new chunk with the current assistant message content
                            updated_chunk = AssistantMessageEvent(
                                message=self._current_assistant_message,
                            )
                            yield updated_chunk
                    case "content_delta":
                        # Accumulate text content to current assistant message
                        self._add_text_content_to_current_assistant_message(chunk.delta)
                        # Always yield content_delta chunk if it's in includes
                        if chunk.type in includes:
                            yield chunk
                    case "function_call":
                        # Add tool call to current assistant message
                        # Keep arguments as string for compatibility with funcall library
                        tool_call = AssistantToolCall(
                            call_id=chunk.call_id,
                            name=chunk.name,
                            arguments=chunk.arguments or "{}",
                        )
                        self._add_to_current_assistant_message(tool_call)
                        # Always yield function_call chunk if it's in includes
                        if chunk.type in includes:
                            yield chunk
                    case "usage":
                        # Update the current or last assistant message with usage data and output_time_ms
                        usage_time = datetime.now(timezone.utc)

                        # Always accumulate usage in runner first
                        self.usage.input_tokens = (self.usage.input_tokens or 0) + (chunk.usage.input_tokens or 0)
                        self.usage.output_tokens = (self.usage.output_tokens or 0) + (chunk.usage.output_tokens or 0)
                        self.usage.total_tokens = (self.usage.total_tokens or 0) + (chunk.usage.input_tokens or 0) + (chunk.usage.output_tokens or 0)

                        # Try to find the assistant message to update
                        target_message = None

                        # First check if we have a current assistant message
                        if self._current_assistant_message is not None:
                            target_message = self._current_assistant_message
                        else:
                            # Otherwise, look for the last assistant message in the list
                            for i in range(len(self.messages) - 1, -1, -1):
                                current_message = self.messages[i]
                                if isinstance(current_message, NewAssistantMessage):
                                    target_message = current_message
                                    break

                        # Update the target message with usage information
                        if target_message is not None:
                            if target_message.meta.usage is None:
                                target_message.meta.usage = MessageUsage()
                            target_message.meta.usage.input_tokens = chunk.usage.input_tokens
                            target_message.meta.usage.output_tokens = chunk.usage.output_tokens
                            target_message.meta.usage.total_tokens = (chunk.usage.input_tokens or 0) + (chunk.usage.output_tokens or 0)

                            # Calculate output_time_ms if latency_ms is available
                            if target_message.meta.latency_ms is not None:
                                # We need to calculate from first output to usage time
                                # We'll calculate: usage_time - (sent_at - latency_ms)
                                # This gives us the time from first output to usage completion
                                # sent_at is when the message was completed, so sent_at - latency_ms approximates first output time
                                first_output_time_approx = target_message.meta.sent_at - timedelta(milliseconds=target_message.meta.latency_ms)
                                output_time_ms = int((usage_time - first_output_time_approx).total_seconds() * 1000)
                                target_message.meta.total_time_ms = max(0, output_time_ms)
                        # Always yield usage chunk if it's in includes
                        if chunk.type in includes:
                            yield chunk
                    case _ if chunk.type in includes:
                        yield chunk

            # Finalize assistant message so it can be found in pending function calls
            self._finalize_assistant_message()

            # Check for pending tool calls after processing current assistant message
            pending_tool_calls = self._find_pending_tool_calls()
            logger.debug(f"Found {len(pending_tool_calls)} pending tool calls")
            if pending_tool_calls:
                # Convert to ToolCall format for existing handler
                tool_calls = self._convert_tool_calls_to_tool_calls(pending_tool_calls)
                require_confirm_tools = await self.agent.list_require_confirm_tools(tool_calls)
                if require_confirm_tools:
                    # Stop the run so the caller can confirm before execution.
                    return
                async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context):
                    yield tool_chunk
                finish_reason = "tool_calls"
            else:
                finish_reason = CompletionMode.STOP
            steps += 1

426 

    async def has_require_confirm_tools(self) -> bool:
        """Return True if any pending tool call requires user confirmation."""
        pending_tool_calls = self._find_pending_tool_calls()
        if not pending_tool_calls:
            return False
        tool_calls = self._convert_tool_calls_to_tool_calls(pending_tool_calls)
        require_confirm_tools = await self.agent.list_require_confirm_tools(tool_calls)
        return bool(require_confirm_tools)

434 

    async def run_continue_until_complete(
        self,
        max_steps: int = 20,
        includes: list[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
    ) -> list[AgentChunk]:
        """Deprecated: Use run_until_complete(None) instead."""
        warnings.warn(
            "run_continue_until_complete is deprecated. Use run_until_complete(None) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Delegate to the (also deprecated) streaming variant and drain it.
        resp = self.run_continue_stream(max_steps, includes, record_to=record_to)
        return await self._collect_all_chunks(resp)

449 

    def run_continue_stream(
        self,
        max_steps: int = 20,
        includes: list[AgentChunkType] | None = None,
        record_to: PathLike | str | None = None,
        context: "Any | None" = None,  # noqa: ANN401
    ) -> AsyncGenerator[AgentChunk, None]:
        """Deprecated: Use run(None) instead."""
        warnings.warn(
            "run_continue_stream is deprecated. Use run(None) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Thin shim over the internal continue implementation.
        return self._run_continue_stream(max_steps, includes, record_to=record_to, context=context)

464 

465 async def _run_continue_stream( 

466 self, 

467 max_steps: int = 20, 

468 includes: Sequence[AgentChunkType] | None = None, 

469 record_to: PathLike | str | None = None, 

470 context: "Any | None" = None, # noqa: ANN401 

471 ) -> AsyncGenerator[AgentChunk, None]: 

472 """Continue running the agent and return a RunResponse object that can be asynchronously iterated for each chunk.""" 

473 includes = self._normalize_includes(includes) 

474 

475 # Find pending tool calls in responses format 

476 pending_tool_calls = self._find_pending_tool_calls() 

477 if pending_tool_calls: 

478 # Convert to ToolCall format for existing handler 

479 tool_calls = self._convert_tool_calls_to_tool_calls(pending_tool_calls) 

480 async for tool_chunk in self._handle_tool_calls(tool_calls, includes, context=context): 

481 yield tool_chunk 

482 async for chunk in self._run(max_steps, includes, self._normalize_record_path(record_to)): 

483 if chunk.type in includes: 

484 yield chunk 

485 else: 

486 # Check if there are any messages and what the last message is 

487 if not self.messages: 

488 msg = "Cannot continue running without a valid last message from the assistant." 

489 raise ValueError(msg) 

490 

491 resp = self._run(max_steps=max_steps, includes=includes, record_to=self._normalize_record_path(record_to), context=context) 

492 async for chunk in resp: 

493 yield chunk 

494 

495 async def run_until_complete( 

496 self, 

497 user_input: UserInput | None = None, 

498 max_steps: int = 20, 

499 includes: list[AgentChunkType] | None = None, 

500 record_to: PathLike | str | None = None, 

501 ) -> list[AgentChunk]: 

502 """Run the agent until it completes and return the final message.""" 

503 resp = self.run(user_input, max_steps, includes, record_to=record_to) 

504 return await self._collect_all_chunks(resp) 

505 

506 def _analyze_last_assistant_message(self) -> tuple[list[AssistantToolCall], dict[str, str]]: 

507 """Analyze the last assistant message and return pending tool calls and tool call map.""" 

508 if not self.messages or not isinstance(self.messages[-1], NewAssistantMessage): 

509 return [], {} 

510 

511 tool_calls = {} 

512 tool_results = set() 

513 tool_call_names = {} 

514 

515 last_message = self.messages[-1] 

516 for content_item in last_message.content: 

517 if isinstance(content_item, AssistantToolCall): 

518 tool_calls[content_item.call_id] = content_item 

519 tool_call_names[content_item.call_id] = content_item.name 

520 elif isinstance(content_item, AssistantToolCallResult): 

521 tool_results.add(content_item.call_id) 

522 

523 # Return pending tool calls and tool call names map 

524 pending_calls = [call for call_id, call in tool_calls.items() if call_id not in tool_results] 

525 return pending_calls, tool_call_names 

526 

    def _find_pending_tool_calls(self) -> list[AssistantToolCall]:
        """Find tool calls that don't have corresponding results yet."""
        # Delegates to the shared analyzer; only the pending list is needed here.
        pending_calls, _ = self._analyze_last_assistant_message()
        return pending_calls

531 

    def _get_tool_call_name_by_id(self, call_id: str) -> str | None:
        """Get the tool name for a given call_id from the last assistant message."""
        # Delegates to the shared analyzer; only the name map is needed here.
        _, tool_call_names = self._analyze_last_assistant_message()
        return tool_call_names.get(call_id)

536 

537 def _cancel_pending_tool_calls(self) -> list[FunctionCallOutputEvent]: 

538 """Cancel all pending tool calls by adding cancellation results. 

539 

540 Returns: 

541 List of FunctionCallOutputEvent for each cancelled tool call 

542 """ 

543 pending_tool_calls = self._find_pending_tool_calls() 

544 if not pending_tool_calls: 

545 return [] 

546 

547 logger.debug(f"Cancelling {len(pending_tool_calls)} pending tool calls due to new user input") 

548 

549 cancellation_events = [] 

550 for tool_call in pending_tool_calls: 

551 output = "Operation cancelled by user - new input provided" 

552 self._add_tool_call_result( 

553 call_id=tool_call.call_id, 

554 output=output, 

555 execution_time_ms=0, 

556 ) 

557 

558 # Create cancellation event 

559 cancellation_event = FunctionCallOutputEvent( 

560 tool_call_id=tool_call.call_id, 

561 name=tool_call.name, 

562 content=output, 

563 execution_time_ms=0, 

564 ) 

565 cancellation_events.append(cancellation_event) 

566 

567 return cancellation_events 

568 

569 def _convert_tool_calls_to_tool_calls(self, tool_calls: list[AssistantToolCall]) -> list[ToolCall]: 

570 """Convert AssistantToolCall objects to ToolCall objects for compatibility.""" 

571 return [ 

572 ToolCall( 

573 id=tc.call_id, 

574 type="function", 

575 function=ToolCallFunction( 

576 name=tc.name, 

577 arguments=tc.arguments if isinstance(tc.arguments, str) else str(tc.arguments), 

578 ), 

579 index=i, 

580 ) 

581 for i, tc in enumerate(tool_calls) 

582 ] 

583 

584 def set_chat_history(self, messages: Sequence[FlexibleInputMessage], root_agent: Agent | None = None) -> None: 

585 """Set the entire chat history and track the current agent based on function calls. 

586 

587 This method analyzes the message history to determine which agent should be active 

588 based on transfer_to_agent and transfer_to_parent function calls. 

589 

590 Args: 

591 messages: List of messages to set as the chat history 

592 root_agent: The root agent to use if no transfers are found. If None, uses self.agent 

593 """ 

594 # Clear current messages 

595 self.messages.clear() 

596 

597 # Set initial agent 

598 current_agent = root_agent if root_agent is not None else self.agent 

599 

600 # Add each message and track agent transfers 

601 for input_message in messages: 

602 # Store length before adding to get the added message 

603 prev_length = len(self.messages) 

604 self.append_message(input_message) 

605 

606 # Track transfers using the converted message (now in self.messages) 

607 if len(self.messages) > prev_length: 

608 converted_message = self.messages[-1] # Get the last added message 

609 current_agent = self._track_agent_transfer_in_message(converted_message, current_agent) 

610 

611 # Set the current agent based on the tracked transfers 

612 self.agent = current_agent 

613 logger.info(f"Chat history set with {len(self.messages)} messages. Current agent: {self.agent.name}") 

614 

    def get_messages(self) -> list[NewMessage]:
        """Get the messages as NewMessage objects.

        Only returns NewMessage objects, filtering out any dict or other legacy formats.
        """
        return [msg for msg in self.messages if isinstance(msg, NewMessage)]

621 

622 def get_dict_messages(self) -> list[dict[str, Any]]: 

623 """Get the messages in JSONL format.""" 

624 result = [] 

625 for msg in self.messages: 

626 if hasattr(msg, "model_dump"): 

627 result.append(msg.model_dump(mode="json")) 

628 elif isinstance(msg, dict): 

629 result.append(msg) 

630 else: 

631 # Fallback for any other message types 

632 result.append(dict(msg)) 

633 return result 

634 

635 def add_user_message(self, text: str) -> None: 

636 """Convenience method to add a user text message.""" 

637 message = NewUserMessage(content=[UserTextContent(text=text)]) 

638 self.append_message(message) 

639 

640 def add_assistant_message(self, text: str) -> None: 

641 """Convenience method to add an assistant text message.""" 

642 message = NewAssistantMessage(content=[AssistantTextContent(text=text)]) 

643 self.append_message(message) 

644 

645 def add_system_message(self, content: str) -> None: 

646 """Convenience method to add a system message.""" 

647 message = NewSystemMessage(content=content) 

648 self.append_message(message) 

649 

650 def _track_agent_transfer_in_message(self, message: FlexibleRunnerMessage, current_agent: Agent) -> Agent: 

651 """Track agent transfers in a single message. 

652 

653 Args: 

654 message: The message to analyze for transfers 

655 current_agent: The currently active agent 

656 

657 Returns: 

658 The agent that should be active after processing this message 

659 """ 

660 if isinstance(message, NewAssistantMessage): 

661 return self._track_transfer_from_new_assistant_message(message, current_agent) 

662 

663 return current_agent 

664 

665 def _track_transfer_from_new_assistant_message(self, message: NewAssistantMessage, current_agent: Agent) -> Agent: 

666 """Track transfers from NewAssistantMessage objects.""" 

667 for content_item in message.content: 

668 if content_item.type == "tool_call": 

669 if content_item.name == ToolName.TRANSFER_TO_AGENT: 

670 arguments = content_item.arguments if isinstance(content_item.arguments, str) else str(content_item.arguments) 

671 return self._handle_transfer_to_agent_tracking(arguments, current_agent) 

672 if content_item.name == ToolName.TRANSFER_TO_PARENT: 

673 return self._handle_transfer_to_parent_tracking(current_agent) 

674 return current_agent 

675 

676 def _handle_transfer_to_agent_tracking(self, arguments: str | dict, current_agent: Agent) -> Agent: 

677 """Handle transfer_to_agent function call tracking.""" 

678 try: 

679 args_dict = json.loads(arguments) if isinstance(arguments, str) else arguments 

680 

681 target_agent_name = args_dict.get("name") 

682 if target_agent_name: 

683 target_agent = self._find_agent_by_name(current_agent, target_agent_name) 

684 if target_agent: 

685 logger.debug(f"History tracking: Transferring from {current_agent.name} to {target_agent_name}") 

686 return target_agent 

687 

688 logger.warning(f"Target agent '{target_agent_name}' not found in handoffs during history setup") 

689 except (json.JSONDecodeError, KeyError, TypeError) as e: 

690 logger.warning(f"Failed to parse transfer_to_agent arguments during history setup: {e}") 

691 

692 return current_agent 

693 

694 def _handle_transfer_to_parent_tracking(self, current_agent: Agent) -> Agent: 

695 """Handle transfer_to_parent function call tracking.""" 

696 if current_agent.parent: 

697 logger.debug(f"History tracking: Transferring from {current_agent.name} back to parent {current_agent.parent.name}") 

698 return current_agent.parent 

699 

700 logger.warning(f"Agent {current_agent.name} has no parent to transfer back to during history setup") 

701 return current_agent 

702 

703 def _find_agent_by_name(self, root_agent: Agent, target_name: str) -> Agent | None: 

704 """Find an agent by name in the handoffs tree starting from root_agent. 

705 

706 Args: 

707 root_agent: The root agent to start searching from 

708 target_name: The name of the agent to find 

709 

710 Returns: 

711 The agent if found, None otherwise 

712 """ 

713 # Check direct handoffs from current agent 

714 if root_agent.handoffs: 

715 for agent in root_agent.handoffs: 

716 if agent.name == target_name: 

717 return agent 

718 

719 # If not found in direct handoffs, check if we need to look in parent's handoffs 

720 # This handles cases where agents can transfer to siblings 

721 current = root_agent 

722 while current.parent is not None: 

723 current = current.parent 

724 if current.handoffs: 

725 for agent in current.handoffs: 

726 if agent.name == target_name: 

727 return agent 

728 

729 return None 

730 

731 def append_message(self, message: FlexibleInputMessage) -> None: 

732 """Append a message to the conversation history. 

733 

734 Accepts both NewMessage format and dict format (which will be converted internally). 

735 """ 

736 if isinstance(message, NewMessage): 

737 self.messages.append(message) 

738 elif isinstance(message, dict): 

739 # Convert dict to NewMessage using MessageBuilder 

740 role = message.get("role", "").lower() 

741 if role == "user": 

742 converted_message = MessageBuilder.build_user_message_from_dict(message) 

743 elif role == "assistant": 

744 converted_message = MessageBuilder.build_assistant_message_from_dict(message) 

745 elif role == "system": 

746 converted_message = MessageBuilder.build_system_message_from_dict(message) 

747 else: 

748 msg = f"Unsupported message role: {role}. Must be 'user', 'assistant', or 'system'." 

749 raise ValueError(msg) 

750 

751 self.messages.append(converted_message) 

752 else: 

753 msg = f"Unsupported message type: {type(message)}. Supports NewMessage types and dict." 

754 raise TypeError(msg) 

755 

756 async def _handle_agent_transfer(self, tool_call: ToolCall) -> tuple[str, str]: 

757 """Handle agent transfer when transfer_to_agent tool is called. 

758 

759 Args: 

760 tool_call: The transfer_to_agent tool call 

761 

762 Returns: 

763 Tuple of (call_id, output) for the tool call result 

764 """ 

765 

766 # Parse the arguments to get the target agent name 

767 try: 

768 arguments = json.loads(tool_call.function.arguments or "{}") 

769 target_agent_name = arguments.get("name") 

770 except (json.JSONDecodeError, KeyError): 

771 logger.error("Failed to parse transfer_to_agent arguments: %s", tool_call.function.arguments) 

772 output = "Failed to parse transfer arguments" 

773 # Add error result to messages 

774 self._add_tool_call_result( 

775 call_id=tool_call.id, 

776 output=output, 

777 ) 

778 return tool_call.id, output 

779 

780 if not target_agent_name: 

781 logger.error("No target agent name provided in transfer_to_agent call") 

782 output = "No target agent name provided" 

783 # Add error result to messages 

784 self._add_tool_call_result( 

785 call_id=tool_call.id, 

786 output=output, 

787 ) 

788 return tool_call.id, output 

789 

790 # Find the target agent in handoffs 

791 if not self.agent.handoffs: 

792 logger.error("Current agent has no handoffs configured") 

793 output = "Current agent has no handoffs configured" 

794 # Add error result to messages 

795 self._add_tool_call_result( 

796 call_id=tool_call.id, 

797 output=output, 

798 ) 

799 return tool_call.id, output 

800 

801 target_agent = None 

802 for agent in self.agent.handoffs: 

803 if agent.name == target_agent_name: 

804 target_agent = agent 

805 break 

806 

807 if not target_agent: 

808 logger.error("Target agent '%s' not found in handoffs", target_agent_name) 

809 output = f"Target agent '{target_agent_name}' not found in handoffs" 

810 # Add error result to messages 

811 self._add_tool_call_result( 

812 call_id=tool_call.id, 

813 output=output, 

814 ) 

815 return tool_call.id, output 

816 

817 # Execute the transfer tool call to get the result 

818 try: 

819 result = await self.agent.fc.call_function_async( 

820 tool_call.function.name, 

821 tool_call.function.arguments or "", 

822 ) 

823 

824 output = str(result) 

825 # Add the tool call result to messages 

826 self._add_tool_call_result( 

827 call_id=tool_call.id, 

828 output=output, 

829 ) 

830 

831 # Switch to the target agent 

832 logger.info("Transferring conversation from %s to %s", self.agent.name, target_agent_name) 

833 self.agent = target_agent 

834 

835 except Exception as e: 

836 logger.exception("Failed to execute transfer_to_agent tool call") 

837 output = f"Transfer failed: {e!s}" 

838 # Add error result to messages 

839 self._add_tool_call_result( 

840 call_id=tool_call.id, 

841 output=output, 

842 ) 

843 return tool_call.id, output 

844 else: 

845 return tool_call.id, output 

846 

847 async def _handle_parent_transfer(self, tool_call: ToolCall) -> tuple[str, str]: 

848 """Handle parent transfer when transfer_to_parent tool is called. 

849 

850 Args: 

851 tool_call: The transfer_to_parent tool call 

852 

853 Returns: 

854 Tuple of (call_id, output) for the tool call result 

855 """ 

856 

857 # Check if current agent has a parent 

858 if not self.agent.parent: 

859 logger.error("Current agent has no parent to transfer back to.") 

860 output = "Current agent has no parent to transfer back to" 

861 # Add error result to messages 

862 self._add_tool_call_result( 

863 call_id=tool_call.id, 

864 output=output, 

865 ) 

866 return tool_call.id, output 

867 

868 # Execute the transfer tool call to get the result 

869 try: 

870 result = await self.agent.fc.call_function_async( 

871 tool_call.function.name, 

872 tool_call.function.arguments or "", 

873 ) 

874 

875 output = str(result) 

876 # Add the tool call result to messages 

877 self._add_tool_call_result( 

878 call_id=tool_call.id, 

879 output=output, 

880 ) 

881 

882 # Switch to the parent agent 

883 logger.info("Transferring conversation from %s back to parent %s", self.agent.name, self.agent.parent.name) 

884 self.agent = self.agent.parent 

885 

886 except Exception as e: 

887 logger.exception("Failed to execute transfer_to_parent tool call") 

888 output = f"Transfer to parent failed: {e!s}" 

889 # Add error result to messages 

890 self._add_tool_call_result( 

891 call_id=tool_call.id, 

892 output=output, 

893 ) 

894 return tool_call.id, output 

895 else: 

896 return tool_call.id, output