Coverage for src / dataknobs_bots / middleware / cost.py: 17%

124 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-16 10:13 -0700

1"""Cost tracking middleware for monitoring LLM usage.""" 

2 

3import json 

4import logging 

5from typing import Any 

6 

7from dataknobs_bots.bot.context import BotContext 

8 

9from .base import Middleware 

10 

# Module-level logger; note the class below also creates its own child
# logger ("<module>.CostTracker") for per-request cost messages.
logger = logging.getLogger(__name__)

12 

13 

class CostTrackingMiddleware(Middleware):
    """Middleware for tracking LLM API costs and usage.

    Monitors token usage across different providers (Ollama, OpenAI, Anthropic, etc.)
    to help optimize costs and track budgets.

    Attributes:
        track_tokens: Whether to track token usage
        cost_rates: Token cost rates per provider/model
        usage_stats: Accumulated usage statistics by client_id

    Example:
        ```python
        # Create middleware with default rates
        middleware = CostTrackingMiddleware()

        # Or with custom rates
        middleware = CostTrackingMiddleware(
            cost_rates={
                "openai": {
                    "gpt-4o": {"input": 0.0025, "output": 0.01},
                },
            }
        )

        # Get stats
        stats = middleware.get_client_stats("my-client")
        total = middleware.get_total_cost()

        # Export to JSON
        json_data = middleware.export_stats_json()
        ```
    """

    # Default cost rates (USD per 1K tokens) - Updated Dec 2024
    DEFAULT_RATES: dict[str, Any] = {
        "ollama": {"input": 0.0, "output": 0.0},  # Free (infrastructure cost only)
        "openai": {
            "gpt-4o": {"input": 0.0025, "output": 0.01},
            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "o1": {"input": 0.015, "output": 0.06},
            "o1-mini": {"input": 0.003, "output": 0.012},
        },
        "anthropic": {
            "claude-3-5-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-5-haiku": {"input": 0.0008, "output": 0.004},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
        },
        "google": {
            "gemini-1.5-pro": {"input": 0.00125, "output": 0.005},
            "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
            "gemini-2.0-flash": {"input": 0.0001, "output": 0.0004},
        },
    }

    def __init__(
        self,
        track_tokens: bool = True,
        cost_rates: dict[str, Any] | None = None,
    ):
        """Initialize cost tracking middleware.

        Args:
            track_tokens: Enable token tracking
            cost_rates: Optional custom cost rates (merged with defaults)
        """
        self.track_tokens = track_tokens
        # BUGFIX: the previous shallow ``DEFAULT_RATES.copy()`` aliased the
        # nested per-provider dicts, so merging custom rates mutated the
        # class-level defaults for every other instance.  Copy each provider
        # table before merging.  (``.update()`` only replaces model entries
        # wholesale, so one level of copying is sufficient.)
        self.cost_rates: dict[str, Any] = {
            provider: rates.copy() if isinstance(rates, dict) else rates
            for provider, rates in self.DEFAULT_RATES.items()
        }
        if cost_rates:
            for provider, rates in cost_rates.items():
                existing = self.cost_rates.get(provider)
                if isinstance(rates, dict) and isinstance(existing, dict):
                    # Merge model-level entries into the provider table.
                    existing.update(rates)
                else:
                    # New provider, or non-dict rates: replace wholesale.
                    self.cost_rates[provider] = rates

        # Per-client accumulated usage, keyed by client_id.
        self._usage_stats: dict[str, dict[str, Any]] = {}
        self._logger = logging.getLogger(f"{__name__}.CostTracker")

    async def before_message(self, message: str, context: BotContext) -> None:
        """Track message before processing (mainly for logging).

        Args:
            message: User's input message
            context: Bot context
        """
        # Estimate input tokens (rough approximation: ~4 chars per token)
        estimated_tokens = len(message) // 4
        self._logger.debug(f"Estimated input tokens: {estimated_tokens}")

    async def after_message(
        self, response: str, context: BotContext, **kwargs: Any
    ) -> None:
        """Track costs after bot response.

        Args:
            response: Bot's generated response
            context: Bot context
            **kwargs: Should contain 'tokens_used', 'provider', 'model' if available
        """
        if not self.track_tokens:
            return

        client_id = context.client_id

        # Extract provider and model info
        provider = kwargs.get("provider", "unknown")
        model = kwargs.get("model", "unknown")

        # Get token counts; fall back to the ~4-chars-per-token estimate.
        tokens_used = kwargs.get("tokens_used", {})
        estimated_input = len(context.session_metadata.get("last_message", "")) // 4
        if isinstance(tokens_used, int):
            # A single number is treated as the total: attribute the estimated
            # input share and the remainder to output.  BUGFIX: clamp at zero
            # so a total below the input estimate can't go negative.
            input_tokens = estimated_input
            output_tokens = max(0, tokens_used - input_tokens)
        else:
            input_tokens = int(
                tokens_used.get(
                    "input",
                    tokens_used.get("prompt_tokens", estimated_input),
                )
            )
            output_tokens = int(
                tokens_used.get(
                    "output",
                    tokens_used.get("completion_tokens", len(response) // 4),
                )
            )

        # Calculate cost and accumulate into the per-client stats tree.
        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        stats = self._record_usage(
            client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            f"Client {client_id}: {provider}/{model} - "
            f"{input_tokens} in + {output_tokens} out tokens, "
            f"cost: ${cost:.6f}, total: ${stats['total_cost_usd']:.6f}"
        )

    async def post_stream(
        self, message: str, response: str, context: BotContext
    ) -> None:
        """Track costs after streaming completes.

        For streaming responses, token counts are estimated from text length
        since exact counts may not be available until the stream completes.

        Args:
            message: Original user message
            response: Complete accumulated response from streaming
            context: Bot context
        """
        if not self.track_tokens:
            return

        # For streaming, we estimate tokens from text length (~4 chars per token)
        input_tokens = len(message) // 4
        output_tokens = len(response) // 4

        # Get provider/model from context metadata if available
        provider = context.session_metadata.get("provider", "unknown")
        model = context.session_metadata.get("model", "unknown")

        # Calculate cost and record through the shared helper.  BUGFIX: this
        # path previously built stats entries missing the ``client_id`` and
        # ``by_model`` keys, which made a later ``after_message`` for the same
        # client/provider raise KeyError.
        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        stats = self._record_usage(
            context.client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            f"Stream complete - Client {context.client_id}: {provider}/{model} - "
            f"~{input_tokens} in + ~{output_tokens} out tokens (estimated), "
            f"cost: ${cost:.6f}, total: ${stats['total_cost_usd']:.6f}"
        )

    async def on_error(
        self, error: Exception, message: str, context: BotContext
    ) -> None:
        """Log errors but don't track costs for failed requests.

        Args:
            error: The exception that occurred
            message: User message that caused the error
            context: Bot context
        """
        self._logger.warning(
            f"Error during request for client {context.client_id}: {error}"
        )

    def _record_usage(
        self,
        client_id: str,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: float,
    ) -> dict[str, Any]:
        """Accumulate one request's usage into the per-client stats tree.

        Shared by ``after_message`` and ``post_stream`` so both paths create
        identically-shaped records (client totals, per-provider, per-model).

        Args:
            client_id: Client identifier
            provider: LLM provider name
            model: Model name
            input_tokens: Input token count for this request
            output_tokens: Output token count for this request
            cost: Cost in USD for this request

        Returns:
            The updated stats dict for ``client_id``.
        """
        stats = self._usage_stats.setdefault(
            client_id,
            {
                "client_id": client_id,
                "total_requests": 0,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
                "by_provider": {},
            },
        )
        stats["total_requests"] += 1
        stats["total_input_tokens"] += input_tokens
        stats["total_output_tokens"] += output_tokens
        stats["total_cost_usd"] += cost

        provider_stats = stats["by_provider"].setdefault(
            provider,
            {
                "requests": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "cost_usd": 0.0,
                "by_model": {},
            },
        )
        # Heal provider entries created by older code without a model table.
        provider_stats.setdefault("by_model", {})
        provider_stats["requests"] += 1
        provider_stats["input_tokens"] += input_tokens
        provider_stats["output_tokens"] += output_tokens
        provider_stats["cost_usd"] += cost

        model_stats = provider_stats["by_model"].setdefault(
            model,
            {
                "requests": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "cost_usd": 0.0,
            },
        )
        model_stats["requests"] += 1
        model_stats["input_tokens"] += input_tokens
        model_stats["output_tokens"] += output_tokens
        model_stats["cost_usd"] += cost

        return stats

    def _calculate_cost(
        self, provider: str, model: str, input_tokens: int, output_tokens: int
    ) -> float:
        """Calculate cost for token usage.

        Args:
            provider: LLM provider name
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost in USD (0.0 when no rates are known for the provider/model)
        """
        provider_rates = self.cost_rates.get(provider)
        if not isinstance(provider_rates, dict):
            return 0.0

        if model in provider_rates:
            # Exact model match.
            rates = provider_rates[model]
        elif "input" in provider_rates:
            # Provider uses a single flat rate table (e.g. ollama).
            rates = provider_rates
        else:
            # Fall back to a partial model-name match in either direction
            # (handles versioned names like "gpt-4o-2024-08-06" vs "gpt-4o").
            for model_key in provider_rates:
                if model_key in model or model in model_key:
                    rates = provider_rates[model_key]
                    break
            else:
                return 0.0

        # Calculate cost (rates are per 1K tokens)
        input_cost = (input_tokens / 1000) * float(rates.get("input", 0.0))
        output_cost = (output_tokens / 1000) * float(rates.get("output", 0.0))
        return float(input_cost + output_cost)

    def get_client_stats(self, client_id: str) -> dict[str, Any] | None:
        """Get usage statistics for a client.

        Args:
            client_id: Client identifier

        Returns:
            Usage statistics or None if not found
        """
        return self._usage_stats.get(client_id)

    def get_all_stats(self) -> dict[str, dict[str, Any]]:
        """Get all usage statistics.

        Returns:
            Dictionary mapping client_id to statistics
        """
        return self._usage_stats.copy()

    def get_total_cost(self) -> float:
        """Get total cost across all clients.

        Returns:
            Total cost in USD
        """
        return float(
            sum(stats["total_cost_usd"] for stats in self._usage_stats.values())
        )

    def get_total_tokens(self) -> dict[str, int]:
        """Get total tokens across all clients.

        Returns:
            Dictionary with 'input', 'output', and 'total' token counts
        """
        input_tokens = sum(
            stats["total_input_tokens"] for stats in self._usage_stats.values()
        )
        output_tokens = sum(
            stats["total_output_tokens"] for stats in self._usage_stats.values()
        )
        return {
            "input": input_tokens,
            "output": output_tokens,
            "total": input_tokens + output_tokens,
        }

    def clear_stats(self, client_id: str | None = None) -> None:
        """Clear usage statistics.

        Args:
            client_id: If provided, clear only this client. Otherwise clear all.
        """
        if client_id:
            if client_id in self._usage_stats:
                del self._usage_stats[client_id]
        else:
            self._usage_stats.clear()

    def export_stats_json(self, indent: int = 2) -> str:
        """Export all statistics as JSON.

        Args:
            indent: JSON indentation level

        Returns:
            JSON string of all statistics
        """
        return json.dumps(self._usage_stats, indent=indent)

    def export_stats_csv(self) -> str:
        """Export statistics as CSV (one row per client).

        Returns:
            CSV string with headers
        """
        lines = [
            "client_id,total_requests,total_input_tokens,total_output_tokens,total_cost_usd"
        ]
        for client_id, stats in self._usage_stats.items():
            lines.append(
                f"{client_id},{stats['total_requests']},"
                f"{stats['total_input_tokens']},{stats['total_output_tokens']},"
                f"{stats['total_cost_usd']:.6f}"
            )
        return "\n".join(lines)