Coverage for src/lite_agent/processors/completion_event_processor.py: 81%

166 statements  

coverage.py v7.10.5, created at 2025-08-25 22:58 +0900

1from collections.abc import AsyncGenerator 

2from datetime import datetime, timezone 

3from typing import Literal 

4 

5import litellm 

6from aiofiles.threadpool.text import AsyncTextIOWrapper 

7from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponseStream, StreamingChoices 

8 

9from lite_agent.loggers import logger 

10from lite_agent.types import ( 

11 AgentChunk, 

12 AssistantMessage, 

13 AssistantMessageEvent, 

14 AssistantMessageMeta, 

15 AssistantTextContent, 

16 CompletionRawEvent, 

17 ContentDeltaEvent, 

18 EventUsage, 

19 FunctionCallDeltaEvent, 

20 FunctionCallEvent, 

21 MessageUsage, 

22 NewAssistantMessage, 

23 Timing, 

24 TimingEvent, 

25 ToolCall, 

26 ToolCallFunction, 

27 UsageEvent, 

28) 

29from lite_agent.utils.metrics import TimingMetrics 

30 

31 

class CompletionEventProcessor:
    """Processor that converts a litellm completion stream into AgentChunk events.

    Accumulates streamed deltas into an ``AssistantMessage`` and emits
    higher-level events as the stream progresses: raw chunk events, content
    deltas, function-call deltas, completed function calls, usage/timing
    events, and a final assistant-message event.
    """

    def __init__(self) -> None:
        # Message being accumulated from streamed deltas; None until the first
        # assistant-role chunk initializes it via initialize_message().
        self._current_message: AssistantMessage | None = None
        self.processing_chunk: Literal["content", "tool_calls"] | None = None
        # Name of the tool call currently receiving argument deltas.
        self.processing_function: str | None = None
        self.last_processed_chunk: ModelResponseStream | None = None
        # True once the final AssistantMessageEvent has been yielded.
        self.yielded_content = False
        # Names of tool calls for which a FunctionCallEvent was already yielded.
        # NOTE(review): keyed by function *name*, so two calls to the same
        # function in one message would be deduplicated — confirm intended.
        self.yielded_function: set[str] = set()
        # Timing markers (UTC) used to compute latency/output-time metrics.
        self._start_time: datetime | None = None
        self._first_output_time: datetime | None = None
        self._output_complete_time: datetime | None = None
        self._usage_time: datetime | None = None
        # Token counts captured from the usage chunk, under the keys
        # "input_tokens" / "output_tokens".
        self._usage_data: dict[str, int] = {}

    async def process_chunk(
        self,
        chunk: ModelResponseStream,
        record_file: AsyncTextIOWrapper | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Process one streamed chunk and yield the resulting agent events.

        Args:
            chunk: The raw streaming chunk from litellm.
            record_file: Optional async text file; when given, each raw chunk
                is appended as one JSON line (useful for replay/debugging).

        Yields:
            AgentChunk events derived from the chunk, in stream order.
        """
        # Mark start time on first chunk.
        if self._start_time is None:
            self._start_time = datetime.now(timezone.utc)

        if record_file:
            await record_file.write(chunk.model_dump_json() + "\n")
            await record_file.flush()
        yield CompletionRawEvent(raw=chunk)

        # A usage chunk carries only token counts; emit usage/timing events
        # and stop processing this chunk.
        usage_chunks = self.handle_usage_chunk(chunk)
        if usage_chunks:
            for usage_chunk in usage_chunks:
                yield usage_chunk
            return
        if not chunk.choices:
            return

        choice = chunk.choices[0]
        delta = choice.delta
        if delta.tool_calls:
            # Text content is complete once tool calls start streaming, so
            # emit the assistant message (with any accumulated text) first.
            if not self.yielded_content:
                self.yielded_content = True
                yield self._build_assistant_message_event(chunk)
            first_tool_call = delta.tool_calls[0]
            tool_name = first_tool_call.function.name if first_tool_call.function else ""
            if tool_name:
                self.processing_function = tool_name
            # When the stream switches to a new function, flush the previous
            # (now complete) tool call as a FunctionCallEvent exactly once.
            if (
                self._current_message
                and self._current_message.tool_calls
                and self.processing_function != self._current_message.tool_calls[-1].function.name
                and self._current_message.tool_calls[-1].function.name not in self.yielded_function
            ):
                tool_call = self._current_message.tool_calls[-1]
                yield FunctionCallEvent(
                    call_id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments or "",
                )
                self.yielded_function.add(tool_call.function.name)
        if not self.is_initialized:
            self.initialize_message(chunk, choice)
        if delta.content and self._current_message:
            # Mark first output time if not already set.
            if self._first_output_time is None:
                self._first_output_time = datetime.now(timezone.utc)
            self._current_message.content += delta.content
            yield ContentDeltaEvent(delta=delta.content)
        if delta.tool_calls is not None:
            self.update_tool_calls(delta.tool_calls)
        if delta.tool_calls and self.current_message.tool_calls:
            tool_call = delta.tool_calls[0]
            message_tool_call = self.current_message.tool_calls[-1]
            yield FunctionCallDeltaEvent(
                tool_call_id=message_tool_call.id,
                name=message_tool_call.function.name,
                arguments_delta=tool_call.function.arguments or "",
            )
        if choice.finish_reason:
            # Mark output complete time when finish_reason appears.
            if self._output_complete_time is None:
                self._output_complete_time = datetime.now(timezone.utc)

            # Flush the last tool call (unconditionally, mirroring the
            # original stream-end behavior), then the assistant message if
            # the tool-call path above has not already emitted it.
            if self.current_message.tool_calls:
                tool_call = self.current_message.tool_calls[-1]
                yield FunctionCallEvent(
                    call_id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments or "",
                )
            if not self.yielded_content:
                self.yielded_content = True
                yield self._build_assistant_message_event(chunk)
        self.last_processed_chunk = chunk

    def _build_assistant_message_event(self, chunk: ModelResponseStream) -> AssistantMessageEvent:
        """Build the final AssistantMessageEvent from the accumulated state.

        Captures send time, timing metrics, usage, model name, and any
        accumulated text content at the moment of the call.
        """
        end_time = datetime.now(timezone.utc)
        latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
        output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)

        usage = MessageUsage(
            input_tokens=self._usage_data.get("input_tokens"),
            output_tokens=self._usage_data.get("output_tokens"),
        )
        # Extract model information from chunk.
        model_name = getattr(chunk, "model", None)
        meta = AssistantMessageMeta(
            sent_at=end_time,
            model=model_name,
            latency_ms=latency_ms,
            total_time_ms=output_time_ms,
            usage=usage,
        )
        # Include accumulated text content in the message.
        content = []
        if self._current_message and self._current_message.content:
            content.append(AssistantTextContent(text=self._current_message.content))

        return AssistantMessageEvent(
            message=NewAssistantMessage(
                content=content,
                meta=meta,
            ),
        )

    def handle_usage_chunk(self, chunk: ModelResponseStream) -> list[AgentChunk]:
        """Extract usage/timing events from a usage-bearing chunk.

        Returns a list with a UsageEvent (and a TimingEvent when timing data
        is available), or an empty list when the chunk carries no usage.
        """
        usage = getattr(chunk, "usage", None)
        if usage:
            # Mark usage time.
            self._usage_time = datetime.now(timezone.utc)
            # Store usage data for meta information.
            self._usage_data["input_tokens"] = usage["prompt_tokens"]
            self._usage_data["output_tokens"] = usage["completion_tokens"]

            results = []

            # First yield usage event.
            results.append(UsageEvent(usage=EventUsage(input_tokens=usage["prompt_tokens"], output_tokens=usage["completion_tokens"])))

            # Then yield timing event if we have timing data.
            latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
            output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)
            if latency_ms is not None and output_time_ms is not None:
                results.append(
                    TimingEvent(
                        timing=Timing(
                            latency_ms=latency_ms,
                            output_time_ms=output_time_ms,
                        ),
                    ),
                )

            return results
        return []

    def initialize_message(self, chunk: ModelResponseStream, choice: StreamingChoices) -> None:
        """Initialize the message object from the first assistant-role chunk."""
        delta = choice.delta
        if delta.role != "assistant":
            logger.warning("Skipping chunk with role: %s", delta.role)
            return
        self._current_message = AssistantMessage(
            id=chunk.id,
            index=choice.index,
            role=delta.role,
            content="",
        )
        logger.debug('Initialized new message: "%s"', self._current_message.id)

    def update_content(self, content: str) -> None:
        """Append text content to the current message, if one exists."""
        if self._current_message and content:
            self._current_message.content += content

    def _initialize_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
        """Initialize tool calls.

        NOTE(review): this resets the tool-call list and only logs each call
        without appending it — looks incomplete or legacy (tool calls are
        actually created in update_tool_calls). Confirm before relying on it.
        """
        if not self._current_message:
            return

        self._current_message.tool_calls = []
        for call in tool_calls:
            logger.debug("Create new tool call: %s", call.id)

    def _update_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
        """Update existing tool calls by appending argument deltas pairwise."""
        if not self._current_message:
            return
        if not hasattr(self._current_message, "tool_calls"):
            self._current_message.tool_calls = []
        if not self._current_message.tool_calls:
            return
        if not tool_calls:
            return
        for current_call, new_call in zip(self._current_message.tool_calls, tool_calls, strict=False):
            if new_call.function.arguments:
                # Initialize empty/None accumulators instead of dropping the
                # delta (mirrors the handling in update_tool_calls).
                if not current_call.function.arguments:
                    current_call.function.arguments = ""
                current_call.function.arguments += new_call.function.arguments
            if new_call.type and new_call.type == "function":
                current_call.type = new_call.type
            elif new_call.type:
                logger.warning("Unexpected tool call type: %s", new_call.type)

    def update_tool_calls(self, tool_calls: list[ChatCompletionDeltaToolCall]) -> None:
        """Handle tool call updates.

        A delta with an ``id`` starts a new tool call; a delta without one
        appends its argument fragment to the call at ``index``.
        """
        if not tool_calls:
            return
        for call in tool_calls:
            if call.id:
                if call.type == "function":
                    new_tool_call = ToolCall(
                        id=call.id,
                        type=call.type,
                        function=ToolCallFunction(
                            name=call.function.name or "",
                            arguments=call.function.arguments,
                        ),
                        index=call.index,
                    )
                    if self._current_message is not None:
                        if self._current_message.tool_calls is None:
                            self._current_message.tool_calls = []
                        self._current_message.tool_calls.append(new_tool_call)
                else:
                    logger.warning("Unexpected tool call type: %s", call.type)
            elif self._current_message is not None and self._current_message.tool_calls is not None and call.index is not None and 0 <= call.index < len(self._current_message.tool_calls):
                existing_call = self._current_message.tool_calls[call.index]
                if call.function.arguments:
                    if existing_call.function.arguments is None:
                        existing_call.function.arguments = ""
                    existing_call.function.arguments += call.function.arguments
            else:
                logger.warning("Cannot update tool call: current_message or tool_calls is None, or invalid index.")

    @property
    def is_initialized(self) -> bool:
        """Check if the current message is initialized"""
        return self._current_message is not None

    @property
    def current_message(self) -> AssistantMessage:
        """Get the current message being processed"""
        if not self._current_message:
            msg = "No current message initialized. Call initialize_message first."
            raise ValueError(msg)
        return self._current_message