Coverage for src/dataknobs_bots/middleware/cost.py: 17%
124 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-16 10:13 -0700
1"""Cost tracking middleware for monitoring LLM usage."""
3import json
4import logging
5from typing import Any
7from dataknobs_bots.bot.context import BotContext
9from .base import Middleware
11logger = logging.getLogger(__name__)
class CostTrackingMiddleware(Middleware):
    """Middleware for tracking LLM API costs and usage.

    Monitors token usage across different providers (Ollama, OpenAI, Anthropic, etc.)
    to help optimize costs and track budgets.

    Attributes:
        track_tokens: Whether to track token usage
        cost_rates: Token cost rates per provider/model
        DEFAULT_RATES: Built-in USD-per-1K-token rates used when no custom
            rates are supplied

    Example:
        ```python
        # Create middleware with default rates
        middleware = CostTrackingMiddleware()

        # Or with custom rates
        middleware = CostTrackingMiddleware(
            cost_rates={
                "openai": {
                    "gpt-4o": {"input": 0.0025, "output": 0.01},
                },
            }
        )

        # Get stats
        stats = middleware.get_client_stats("my-client")
        total = middleware.get_total_cost()

        # Export to JSON
        json_data = middleware.export_stats_json()
        ```
    """

    # Default cost rates (USD per 1K tokens) - Updated Dec 2024
    DEFAULT_RATES: dict[str, Any] = {
        "ollama": {"input": 0.0, "output": 0.0},  # Free (infrastructure cost only)
        "openai": {
            "gpt-4o": {"input": 0.0025, "output": 0.01},
            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "o1": {"input": 0.015, "output": 0.06},
            "o1-mini": {"input": 0.003, "output": 0.012},
        },
        "anthropic": {
            "claude-3-5-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-5-haiku": {"input": 0.0008, "output": 0.004},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
        },
        "google": {
            "gemini-1.5-pro": {"input": 0.00125, "output": 0.005},
            "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003},
            "gemini-2.0-flash": {"input": 0.0001, "output": 0.0004},
        },
    }

    def __init__(
        self,
        track_tokens: bool = True,
        cost_rates: dict[str, Any] | None = None,
    ):
        """Initialize cost tracking middleware.

        Args:
            track_tokens: Enable token tracking
            cost_rates: Optional custom cost rates (merged with defaults)
        """
        self.track_tokens = track_tokens
        # Copy DEFAULT_RATES one level deep: the merge below mutates
        # provider-level dicts with .update(), and a plain shallow .copy()
        # would leak custom rates into the shared class attribute (affecting
        # every other instance). One level is sufficient because the
        # innermost per-model dicts are replaced wholesale, never mutated.
        self.cost_rates: dict[str, Any] = {
            provider: rates.copy() if isinstance(rates, dict) else rates
            for provider, rates in self.DEFAULT_RATES.items()
        }
        if cost_rates:
            for provider, rates in cost_rates.items():
                existing = self.cost_rates.get(provider)
                if isinstance(rates, dict) and isinstance(existing, dict):
                    # Merge per-model entries into the provider's table.
                    existing.update(rates)
                else:
                    # New provider, or non-dict rate spec: replace wholesale.
                    self.cost_rates[provider] = rates

        # Accumulated usage statistics keyed by client_id.
        self._usage_stats: dict[str, dict[str, Any]] = {}
        self._logger = logging.getLogger(f"{__name__}.CostTracker")

    async def before_message(self, message: str, context: BotContext) -> None:
        """Track message before processing (mainly for logging).

        Args:
            message: User's input message
            context: Bot context
        """
        # Estimate input tokens (rough approximation: ~4 chars per token)
        estimated_tokens = len(message) // 4
        self._logger.debug(f"Estimated input tokens: {estimated_tokens}")

    async def after_message(
        self, response: str, context: BotContext, **kwargs: Any
    ) -> None:
        """Track costs after bot response.

        Args:
            response: Bot's generated response
            context: Bot context
            **kwargs: Should contain 'tokens_used', 'provider', 'model' if available
        """
        if not self.track_tokens:
            return

        client_id = context.client_id

        # Extract provider and model info
        provider = kwargs.get("provider", "unknown")
        model = kwargs.get("model", "unknown")

        # Get token counts. Fallback estimate: ~4 chars per token of the
        # last user message recorded in session metadata.
        tokens_used = kwargs.get("tokens_used", {})
        estimated_input = len(context.session_metadata.get("last_message", "")) // 4
        if isinstance(tokens_used, int):
            # A single number is treated as the total; estimate the split.
            # Clamp so a reported total smaller than the input estimate can
            # never yield negative output tokens.
            input_tokens = estimated_input
            output_tokens = max(0, tokens_used - input_tokens)
        else:
            input_tokens = int(
                tokens_used.get(
                    "input",
                    tokens_used.get("prompt_tokens", estimated_input),
                )
            )
            output_tokens = int(
                tokens_used.get(
                    "output",
                    tokens_used.get("completion_tokens", len(response) // 4),
                )
            )

        # Calculate cost and accumulate into the shared stats schema.
        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        stats = self._record_usage(
            client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            f"Client {client_id}: {provider}/{model} - "
            f"{input_tokens} in + {output_tokens} out tokens, "
            f"cost: ${cost:.6f}, total: ${stats['total_cost_usd']:.6f}"
        )

    async def post_stream(
        self, message: str, response: str, context: BotContext
    ) -> None:
        """Track costs after streaming completes.

        For streaming responses, token counts are estimated from text length
        since exact counts may not be available until the stream completes.

        Args:
            message: Original user message
            response: Complete accumulated response from streaming
            context: Bot context
        """
        if not self.track_tokens:
            return

        client_id = context.client_id

        # For streaming, we estimate tokens from text length (~4 chars per token)
        input_tokens = len(message) // 4
        output_tokens = len(response) // 4

        # Get provider/model from context metadata if available
        provider = context.session_metadata.get("provider", "unknown")
        model = context.session_metadata.get("model", "unknown")

        cost = self._calculate_cost(provider, model, input_tokens, output_tokens)
        # Use the same bookkeeping as after_message so both paths produce an
        # identical stats schema (previously this path omitted 'client_id'
        # and 'by_model', which could raise KeyError in after_message later).
        stats = self._record_usage(
            client_id, provider, model, input_tokens, output_tokens, cost
        )

        self._logger.info(
            f"Stream complete - Client {client_id}: {provider}/{model} - "
            f"~{input_tokens} in + ~{output_tokens} out tokens (estimated), "
            f"cost: ${cost:.6f}, total: ${stats['total_cost_usd']:.6f}"
        )

    async def on_error(
        self, error: Exception, message: str, context: BotContext
    ) -> None:
        """Log errors but don't track costs for failed requests.

        Args:
            error: The exception that occurred
            message: User message that caused the error
            context: Bot context
        """
        self._logger.warning(
            f"Error during request for client {context.client_id}: {error}"
        )

    def _record_usage(
        self,
        client_id: str,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cost: float,
    ) -> dict[str, Any]:
        """Accumulate one request's usage into the per-client statistics.

        Updates totals at the client, provider, and model level. Centralizing
        this keeps the streaming and non-streaming code paths on a single,
        consistent stats schema.

        Args:
            client_id: Client identifier
            provider: LLM provider name
            model: Model name
            input_tokens: Input token count for this request
            output_tokens: Output token count for this request
            cost: Cost in USD for this request

        Returns:
            The (mutated) per-client statistics dictionary.
        """
        stats = self._usage_stats.setdefault(
            client_id,
            {
                "client_id": client_id,
                "total_requests": 0,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
                "by_provider": {},
            },
        )
        stats["total_requests"] += 1
        stats["total_input_tokens"] += input_tokens
        stats["total_output_tokens"] += output_tokens
        stats["total_cost_usd"] += cost

        provider_stats = stats["by_provider"].setdefault(
            provider,
            {
                "requests": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "cost_usd": 0.0,
                "by_model": {},
            },
        )
        # Defensive: repair entries created by older code without 'by_model'.
        provider_stats.setdefault("by_model", {})
        provider_stats["requests"] += 1
        provider_stats["input_tokens"] += input_tokens
        provider_stats["output_tokens"] += output_tokens
        provider_stats["cost_usd"] += cost

        model_stats = provider_stats["by_model"].setdefault(
            model,
            {
                "requests": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "cost_usd": 0.0,
            },
        )
        model_stats["requests"] += 1
        model_stats["input_tokens"] += input_tokens
        model_stats["output_tokens"] += output_tokens
        model_stats["cost_usd"] += cost

        return stats

    def _calculate_cost(
        self, provider: str, model: str, input_tokens: int, output_tokens: int
    ) -> float:
        """Calculate cost for token usage.

        Rate lookup order: exact model match, then provider-wide generic
        rates (e.g. ollama), then the longest partially-matching model key
        (so "gpt-4o-mini-2024-..." matches "gpt-4o-mini", not "gpt-4o").
        Unknown providers/models cost 0.0.

        Args:
            provider: LLM provider name
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost in USD
        """
        provider_rates = self.cost_rates.get(provider)
        if not isinstance(provider_rates, dict):
            return 0.0

        if model in provider_rates:
            rates = provider_rates[model]
        elif "input" in provider_rates:
            # Generic per-provider rates (e.g. ollama has no per-model table)
            rates = provider_rates
        else:
            # Partial model name match; prefer the longest (most specific) key
            # so versioned model names resolve to the right rate entry.
            candidates = [
                key for key in provider_rates if key in model or model in key
            ]
            if not candidates:
                return 0.0
            rates = provider_rates[max(candidates, key=len)]

        # Rates are USD per 1K tokens
        input_cost = (input_tokens / 1000) * float(rates.get("input", 0.0))
        output_cost = (output_tokens / 1000) * float(rates.get("output", 0.0))
        return float(input_cost + output_cost)

    def get_client_stats(self, client_id: str) -> dict[str, Any] | None:
        """Get usage statistics for a client.

        Args:
            client_id: Client identifier

        Returns:
            Usage statistics or None if not found
        """
        return self._usage_stats.get(client_id)

    def get_all_stats(self) -> dict[str, dict[str, Any]]:
        """Get all usage statistics.

        Returns:
            Dictionary mapping client_id to statistics
        """
        return self._usage_stats.copy()

    def get_total_cost(self) -> float:
        """Get total cost across all clients.

        Returns:
            Total cost in USD
        """
        return float(
            sum(stats["total_cost_usd"] for stats in self._usage_stats.values())
        )

    def get_total_tokens(self) -> dict[str, int]:
        """Get total tokens across all clients.

        Returns:
            Dictionary with 'input', 'output', and 'total' token counts
        """
        input_tokens = sum(
            stats["total_input_tokens"] for stats in self._usage_stats.values()
        )
        output_tokens = sum(
            stats["total_output_tokens"] for stats in self._usage_stats.values()
        )
        return {
            "input": input_tokens,
            "output": output_tokens,
            "total": input_tokens + output_tokens,
        }

    def clear_stats(self, client_id: str | None = None) -> None:
        """Clear usage statistics.

        Args:
            client_id: If provided, clear only this client. Otherwise clear all.
        """
        if client_id:
            if client_id in self._usage_stats:
                del self._usage_stats[client_id]
        else:
            self._usage_stats.clear()

    def export_stats_json(self, indent: int = 2) -> str:
        """Export all statistics as JSON.

        Args:
            indent: JSON indentation level

        Returns:
            JSON string of all statistics
        """
        return json.dumps(self._usage_stats, indent=indent)

    def export_stats_csv(self) -> str:
        """Export statistics as CSV (one row per client).

        Returns:
            CSV string with headers
        """
        lines = [
            "client_id,total_requests,total_input_tokens,total_output_tokens,total_cost_usd"
        ]
        for client_id, stats in self._usage_stats.items():
            lines.append(
                f"{client_id},{stats['total_requests']},"
                f"{stats['total_input_tokens']},{stats['total_output_tokens']},"
                f"{stats['total_cost_usd']:.6f}"
            )
        return "\n".join(lines)