Coverage for src/lite_agent/processors/response_event_processor.py: 97%

105 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-25 22:58 +0900

1from collections.abc import AsyncGenerator 

2from datetime import datetime, timezone 

3from typing import Any 

4 

5from aiofiles.threadpool.text import AsyncTextIOWrapper 

6from litellm.types.llms.openai import ( 

7 ContentPartAddedEvent, 

8 FunctionCallArgumentsDeltaEvent, 

9 FunctionCallArgumentsDoneEvent, 

10 OutputItemAddedEvent, 

11 OutputItemDoneEvent, 

12 OutputTextDeltaEvent, 

13 ResponseCompletedEvent, 

14 ResponsesAPIStreamEvents, 

15 ResponsesAPIStreamingResponse, 

16) 

17 

18from lite_agent.types import ( 

19 AgentChunk, 

20 AssistantMessageEvent, 

21 AssistantMessageMeta, 

22 ContentDeltaEvent, 

23 EventUsage, 

24 FunctionCallEvent, 

25 MessageUsage, 

26 NewAssistantMessage, 

27 ResponseRawEvent, 

28 Timing, 

29 TimingEvent, 

30 UsageEvent, 

31) 

32from lite_agent.utils.metrics import TimingMetrics 

33 

34 

class ResponseEventProcessor:
    """Accumulate Responses-API streaming events into messages and emit AgentChunk events.

    The processor keeps a running list of message dicts built from the stream,
    and tracks timestamps (stream start, first output, output complete, usage
    received) plus token usage so that latency / output-time metrics can be
    attached to the final assistant message. Call :meth:`reset` before reusing
    an instance for a new response stream.
    """

    def __init__(self) -> None:
        # Raw message dicts accumulated from OutputItemAdded events, in arrival order.
        self._messages: list[dict[str, Any]] = []
        # Timestamps used to derive latency (start -> first output) and
        # output time (first output -> output complete).
        self._start_time: datetime | None = None
        self._first_output_time: datetime | None = None
        self._output_complete_time: datetime | None = None
        self._usage_time: datetime | None = None
        # Token counts captured from the terminal ResponseCompletedEvent.
        self._usage_data: dict[str, Any] = {}

    async def process_chunk(
        self,
        chunk: ResponsesAPIStreamingResponse,
        record_file: AsyncTextIOWrapper | None = None,
    ) -> AsyncGenerator[AgentChunk, None]:
        """Process one raw stream chunk.

        Yields a ResponseRawEvent wrapping the chunk first, then any
        higher-level events derived from it. If *record_file* is given, the
        chunk is appended to it as a JSON line (flushed immediately so a crash
        loses at most the current chunk).
        """
        # The very first chunk marks the start of the stream for latency metrics.
        if self._start_time is None:
            self._start_time = datetime.now(timezone.utc)

        if record_file:
            await record_file.write(chunk.model_dump_json() + "\n")
            await record_file.flush()

        yield ResponseRawEvent(raw=chunk)

        for event in self.handle_event(chunk):
            yield event

    def handle_event(self, event: ResponsesAPIStreamingResponse) -> list[AgentChunk]:  # noqa: PLR0911
        """Dispatch a single response event and return any derived AgentChunk events."""
        # Lifecycle/bookkeeping events that carry nothing we need to surface.
        if event.type in (
            ResponsesAPIStreamEvents.RESPONSE_CREATED,
            ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE,
            ResponsesAPIStreamEvents.CONTENT_PART_DONE,
        ):
            return []

        if isinstance(event, OutputItemAddedEvent):
            self._messages.append(event.item)  # type: ignore
            return []

        if isinstance(event, ContentPartAddedEvent):
            self._append_content_part(event)
            return []

        if isinstance(event, OutputTextDeltaEvent):
            return self._handle_text_delta(event)

        if isinstance(event, OutputItemDoneEvent):
            return self._handle_output_item_done(event)

        if isinstance(event, FunctionCallArgumentsDeltaEvent):
            self._handle_arguments_delta(event)
            return []

        if isinstance(event, FunctionCallArgumentsDoneEvent):
            self._handle_arguments_done(event)
            return []

        if isinstance(event, ResponseCompletedEvent):
            return self._handle_response_completed(event)

        return []

    def _append_content_part(self, event: ContentPartAddedEvent) -> None:
        """Attach a new content part to the most recent message, if one exists."""
        latest_message = self._messages[-1] if self._messages else None
        if latest_message and isinstance(latest_message.get("content"), list):
            latest_message["content"].append(event.part)

    def _handle_text_delta(self, event: OutputTextDeltaEvent) -> list[AgentChunk]:
        """Append a text delta to the latest content part and emit a ContentDeltaEvent."""
        # The first text delta marks the start of model output (latency endpoint).
        if self._first_output_time is None:
            self._first_output_time = datetime.now(timezone.utc)

        latest_message = self._messages[-1] if self._messages else None
        if latest_message and isinstance(latest_message.get("content"), list):
            latest_content = latest_message["content"][-1]
            if "text" in latest_content:
                latest_content["text"] += event.delta
                return [ContentDeltaEvent(delta=event.delta)]
        # No message or text content part to append to: drop the delta silently.
        return []

    def _handle_output_item_done(self, event: OutputItemDoneEvent) -> list[AgentChunk]:
        """Emit a FunctionCallEvent or AssistantMessageEvent for a completed item."""
        item = event.item
        if item.get("type") == "function_call":
            return [
                FunctionCallEvent(
                    call_id=item["call_id"],
                    name=item["name"],
                    arguments=item["arguments"],
                ),
            ]
        if item.get("type") == "message":
            # A finished message marks the end of model output (output-time endpoint).
            if self._output_complete_time is None:
                self._output_complete_time = datetime.now(timezone.utc)

            content = item.get("content", [])
            if content and isinstance(content, list):
                return [
                    AssistantMessageEvent(
                        # Content is delivered incrementally via deltas; only
                        # the meta (timing/usage/model) rides on this event.
                        message=NewAssistantMessage(content=[], meta=self._build_message_meta(event)),
                    ),
                ]
        return []

    def _build_message_meta(self, event: OutputItemDoneEvent) -> AssistantMessageMeta:
        """Assemble timing, model, and usage metadata for a completed assistant message."""
        end_time = datetime.now(timezone.utc)
        latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
        output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)

        # Prefer the model recorded on the nested response object, falling back
        # to a model attribute on the event itself.
        model_name = getattr(event, "model", None)
        response = getattr(event, "response", None)
        if response and hasattr(response, "model"):
            model_name = response.model

        input_tokens = self._usage_data.get("input_tokens")
        output_tokens = self._usage_data.get("output_tokens")
        usage = MessageUsage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=(input_tokens or 0) + (output_tokens or 0),
        )
        return AssistantMessageMeta(
            sent_at=end_time,
            model=model_name,
            latency_ms=latency_ms,
            output_time_ms=output_time_ms,
            usage=usage,
        )

    def _handle_arguments_delta(self, event: FunctionCallArgumentsDeltaEvent) -> None:
        """Accumulate a streamed arguments fragment onto the latest function call."""
        if self._messages:
            latest_message = self._messages[-1]
            if latest_message.get("type") == "function_call":
                if "arguments" not in latest_message:
                    latest_message["arguments"] = ""
                latest_message["arguments"] += event.delta

    def _handle_arguments_done(self, event: FunctionCallArgumentsDoneEvent) -> None:
        """Replace accumulated arguments with the authoritative final string."""
        if self._messages:
            latest_message = self._messages[-1]
            if latest_message.get("type") == "function_call":
                latest_message["arguments"] = event.arguments

    def _handle_response_completed(self, event: ResponseCompletedEvent) -> list[AgentChunk]:
        """Record usage data and emit UsageEvent (and TimingEvent when timing is known)."""
        usage = event.response.usage
        if not usage:
            return []

        self._usage_time = datetime.now(timezone.utc)
        # Stash token counts (and the usage timestamp) for the message meta.
        self._usage_data["input_tokens"] = usage.input_tokens
        self._usage_data["output_tokens"] = usage.output_tokens
        self._usage_data["usage_time"] = self._usage_time

        results: list[AgentChunk] = [
            UsageEvent(
                usage=EventUsage(
                    input_tokens=usage.input_tokens,
                    output_tokens=usage.output_tokens,
                ),
            ),
        ]

        # Only emit timing when both metrics could be computed (i.e. output
        # actually started and finished).
        latency_ms = TimingMetrics.calculate_latency_ms(self._start_time, self._first_output_time)
        output_time_ms = TimingMetrics.calculate_output_time_ms(self._first_output_time, self._output_complete_time)
        if latency_ms is not None and output_time_ms is not None:
            results.append(
                TimingEvent(
                    timing=Timing(
                        latency_ms=latency_ms,
                        output_time_ms=output_time_ms,
                    ),
                ),
            )

        return results

    @property
    def messages(self) -> list[dict[str, Any]]:
        """The accumulated messages (a live reference, not a copy)."""
        return self._messages

    def reset(self) -> None:
        """Clear all accumulated state so the processor can handle a new stream."""
        self._messages = []
        self._start_time = None
        self._first_output_time = None
        self._output_complete_time = None
        self._usage_time = None
        self._usage_data = {}