Coverage for src/lite_agent/processors/completion_event_processor.py: 81%
166 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-25 22:58 +0900
1from collections.abc import AsyncGenerator
2from datetime import datetime, timezone
3from typing import Literal
5import litellm
6from aiofiles.threadpool.text import AsyncTextIOWrapper
7from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponseStream, StreamingChoices
9from lite_agent.loggers import logger
10from lite_agent.types import (
11 AgentChunk,
12 AssistantMessage,
13 AssistantMessageEvent,
14 AssistantMessageMeta,
15 AssistantTextContent,
16 CompletionRawEvent,
17 ContentDeltaEvent,
18 EventUsage,
19 FunctionCallDeltaEvent,
20 FunctionCallEvent,
21 MessageUsage,
22 NewAssistantMessage,
23 Timing,
24 TimingEvent,
25 ToolCall,
26 ToolCallFunction,
27 UsageEvent,
28)
29from lite_agent.utils.metrics import TimingMetrics
class CompletionEventProcessor:
    """Processor for handling completion events streamed from litellm.

    Accumulates streamed deltas into an ``AssistantMessage`` while yielding
    intermediate events (raw chunk, content delta, tool-call delta) and
    terminal events (completed function calls, the assembled assistant
    message, usage and timing) as the stream progresses.
    """

    def __init__(self) -> None:
        # Message being assembled from streamed deltas; None until the first
        # assistant-role chunk initializes it.
        self._current_message: AssistantMessage | None = None
        self.processing_chunk: Literal["content", "tool_calls"] | None = None
        # Name of the tool call currently being streamed, if any.
        self.processing_function: str | None = None
        self.last_processed_chunk: ModelResponseStream | None = None
        # Guards so the assembled message / each function call is emitted once.
        self.yielded_content = False
        self.yielded_function = set()
        # Timing markers backing the latency / output-time metrics.
        self._start_time: datetime | None = None
        self._first_output_time: datetime | None = None
        self._output_complete_time: datetime | None = None
        self._usage_time: datetime | None = None
        self._usage_data: dict[str, int] = {}

    def _build_assistant_message_event(self, chunk: ModelResponseStream) -> AssistantMessageEvent:
        """Assemble the terminal AssistantMessageEvent for the current message.

        Captures timing metrics, any token usage recorded so far, the model
        name from *chunk*, and the accumulated text content. Extracted so the
        tool-call path and the finish-reason path share one implementation
        (they previously carried identical copies of this logic).
        """
        end_time = datetime.now(timezone.utc)
        latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
        output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)

        usage = MessageUsage(
            input_tokens=self._usage_data.get("input_tokens"),
            output_tokens=self._usage_data.get("output_tokens"),
        )
        # Extract model information from chunk (may be absent on some chunks).
        model_name = getattr(chunk, "model", None)
        meta = AssistantMessageMeta(
            sent_at=end_time,
            model=model_name,
            latency_ms=latency_ms,
            total_time_ms=output_time_ms,
            usage=usage,
        )
        # Include accumulated text content in the message, if any.
        content = []
        if self._current_message and self._current_message.content:
            content.append(AssistantTextContent(text=self._current_message.content))

        return AssistantMessageEvent(
            message=NewAssistantMessage(
                content=content,
                meta=meta,
            ),
        )

    async def process_chunk(
        self,
        chunk: ModelResponseStream,
        record_file: AsyncTextIOWrapper | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Process one streamed chunk, yielding the events it produces.

        Args:
            chunk: The raw streaming chunk from litellm.
            record_file: Optional async text file; each chunk is appended to
                it as one JSON line (and flushed) before processing.

        Yields:
            AgentChunk events derived from the chunk: the raw event, usage
            and timing events, content / tool-call deltas, completed function
            calls, and the final assembled assistant message.
        """
        # Mark start time on first chunk.
        if self._start_time is None:
            self._start_time = datetime.now(timezone.utc)

        if record_file:
            await record_file.write(chunk.model_dump_json() + "\n")
            await record_file.flush()
        yield CompletionRawEvent(raw=chunk)

        # Usage-bearing chunks produce usage/timing events and nothing else.
        usage_chunks = self.handle_usage_chunk(chunk)
        if usage_chunks:
            for usage_chunk in usage_chunks:
                yield usage_chunk
            return
        if not chunk.choices:
            return

        choice = chunk.choices[0]
        delta = choice.delta
        if delta.tool_calls:
            # Text output is complete once tool calls start streaming:
            # emit the assembled assistant message exactly once.
            if not self.yielded_content:
                self.yielded_content = True
                yield self._build_assistant_message_event(chunk)
            first_tool_call = delta.tool_calls[0]
            tool_name = first_tool_call.function.name if first_tool_call.function else ""
            if tool_name:
                self.processing_function = tool_name
        # When the stream has switched to a different tool, flush the previous
        # (now complete) tool call as a FunctionCallEvent, once per name.
        if (
            self._current_message
            and self._current_message.tool_calls
            and self.processing_function != self._current_message.tool_calls[-1].function.name
            and self._current_message.tool_calls[-1].function.name not in self.yielded_function
        ):
            tool_call = self._current_message.tool_calls[-1]
            yield FunctionCallEvent(
                call_id=tool_call.id,
                name=tool_call.function.name,
                arguments=tool_call.function.arguments or "",
            )
            self.yielded_function.add(tool_call.function.name)
        if not self.is_initialized:
            self.initialize_message(chunk, choice)
        if delta.content and self._current_message:
            # Mark first output time if not already set.
            if self._first_output_time is None:
                self._first_output_time = datetime.now(timezone.utc)
            self._current_message.content += delta.content
            yield ContentDeltaEvent(delta=delta.content)
        if delta.tool_calls is not None:
            self.update_tool_calls(delta.tool_calls)
        if delta.tool_calls and self.current_message.tool_calls:
            tool_call = delta.tool_calls[0]
            message_tool_call = self.current_message.tool_calls[-1]
            yield FunctionCallDeltaEvent(
                tool_call_id=message_tool_call.id,
                name=message_tool_call.function.name,
                arguments_delta=tool_call.function.arguments or "",
            )
        if choice.finish_reason:
            # Mark output complete time when finish_reason appears.
            if self._output_complete_time is None:
                self._output_complete_time = datetime.now(timezone.utc)

            # Flush the last pending tool call, then the assembled message
            # (if it has not already been emitted on the tool-call path).
            if self.current_message.tool_calls:
                tool_call = self.current_message.tool_calls[-1]
                yield FunctionCallEvent(
                    call_id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments or "",
                )
            if not self.yielded_content:
                self.yielded_content = True
                yield self._build_assistant_message_event(chunk)
        self.last_processed_chunk = chunk

    def handle_usage_chunk(self, chunk: ModelResponseStream) -> list[AgentChunk]:
        """Extract usage (and, when available, timing) events from *chunk*.

        Returns:
            A list containing a UsageEvent and possibly a TimingEvent, or an
            empty list when the chunk carries no usage payload.
        """
        usage = getattr(chunk, "usage", None)
        if usage:
            # Mark usage time.
            self._usage_time = datetime.now(timezone.utc)
            # Store usage data for meta information on the final message.
            self._usage_data["input_tokens"] = usage["prompt_tokens"]
            self._usage_data["output_tokens"] = usage["completion_tokens"]

            results = []

            # First yield usage event.
            results.append(UsageEvent(usage=EventUsage(input_tokens=usage["prompt_tokens"], output_tokens=usage["completion_tokens"])))

            # Then yield timing event if we have timing data.
            latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
            output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)
            if latency_ms is not None and output_time_ms is not None:
                results.append(
                    TimingEvent(
                        timing=Timing(
                            latency_ms=latency_ms,
                            output_time_ms=output_time_ms,
                        ),
                    ),
                )

            return results
        return []

    def initialize_message(self, chunk: ModelResponseStream, choice: StreamingChoices) -> None:
        """Initialize the message object from the first assistant-role chunk.

        Non-assistant roles are skipped with a warning and leave the
        processor uninitialized.
        """
        delta = choice.delta
        if delta.role != "assistant":
            logger.warning("Skipping chunk with role: %s", delta.role)
            return
        self._current_message = AssistantMessage(
            id=chunk.id,
            index=choice.index,
            role=delta.role,
            content="",
        )
        logger.debug('Initialized new message: "%s"', self._current_message.id)

    def update_content(self, content: str) -> None:
        """Append *content* to the current message, if one exists."""
        if self._current_message and content:
            self._current_message.content += content

    def _initialize_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
        """Initialize tool calls.

        NOTE(review): this resets the list and logs each incoming call but
        never appends any of them — presumably either dead code or a latent
        bug; confirm against callers before relying on it.
        """
        if not self._current_message:
            return

        self._current_message.tool_calls = []
        for call in tool_calls:
            logger.debug("Create new tool call: %s", call.id)

    def _update_tool_calls(self, tool_calls: list[litellm.ChatCompletionMessageToolCall]) -> None:
        """Update existing tool calls by concatenating streamed arguments.

        Pairs existing and incoming calls positionally; extra items on
        either side are ignored (``strict=False``).
        """
        if not self._current_message:
            return
        if not hasattr(self._current_message, "tool_calls"):
            self._current_message.tool_calls = []
        if not self._current_message.tool_calls:
            return
        if not tool_calls:
            return
        for current_call, new_call in zip(self._current_message.tool_calls, tool_calls, strict=False):
            if new_call.function.arguments and current_call.function.arguments:
                current_call.function.arguments += new_call.function.arguments
            if new_call.type and new_call.type == "function":
                current_call.type = new_call.type
            elif new_call.type:
                logger.warning("Unexpected tool call type: %s", new_call.type)

    def update_tool_calls(self, tool_calls: list[ChatCompletionDeltaToolCall]) -> None:
        """Handle tool call updates.

        A delta carrying an ``id`` starts a new tool call; an id-less delta
        appends its argument fragment to the existing call at its ``index``.
        """
        if not tool_calls:
            return
        for call in tool_calls:
            if call.id:
                if call.type == "function":
                    new_tool_call = ToolCall(
                        id=call.id,
                        type=call.type,
                        function=ToolCallFunction(
                            name=call.function.name or "",
                            arguments=call.function.arguments,
                        ),
                        index=call.index,
                    )
                    if self._current_message is not None:
                        if self._current_message.tool_calls is None:
                            self._current_message.tool_calls = []
                        self._current_message.tool_calls.append(new_tool_call)
                else:
                    logger.warning("Unexpected tool call type: %s", call.type)
            elif self._current_message is not None and self._current_message.tool_calls is not None and call.index is not None and 0 <= call.index < len(self._current_message.tool_calls):
                existing_call = self._current_message.tool_calls[call.index]
                if call.function.arguments:
                    if existing_call.function.arguments is None:
                        existing_call.function.arguments = ""
                    existing_call.function.arguments += call.function.arguments
            else:
                logger.warning("Cannot update tool call: current_message or tool_calls is None, or invalid index.")

    @property
    def is_initialized(self) -> bool:
        """Check if the current message is initialized"""
        return self._current_message is not None

    @property
    def current_message(self) -> AssistantMessage:
        """Get the current message being processed.

        Raises:
            ValueError: if no message has been initialized yet.
        """
        if not self._current_message:
            msg = "No current message initialized. Call initialize_message first."
            raise ValueError(msg)
        return self._current_message