Coverage for src / dataknobs_bots / knowledge / retrieval / formatter.py: 22%
86 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-16 10:13 -0700
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-16 10:13 -0700
1"""Context formatting utilities for RAG retrieval.
3This module provides formatting for retrieved chunks to optimize
4LLM context window usage and improve comprehension.
5"""
7from __future__ import annotations
9from dataclasses import dataclass
10from typing import Any
12from dataknobs_bots.knowledge.retrieval.merger import MergedChunk
@dataclass
class FormatterConfig:
    """Settings that control how ``ContextFormatter`` renders retrieved chunks.

    Attributes:
        small_chunk_threshold: Chunks whose content length (in characters) is
            below this value are rendered with their full heading path.
        medium_chunk_threshold: Chunks below this length that are not "small"
            are rendered with only their last two headings; longer chunks are
            rendered without headings.
        include_scores: Whether to show each chunk's similarity score.
        include_source: Whether to show each chunk's source file.
        group_by_source: Whether to group chunks under per-source headers.
    """

    # Character-length tiers used for dynamic heading inclusion.
    small_chunk_threshold: int = 200
    medium_chunk_threshold: int = 800

    # Presentation toggles.
    include_scores: bool = False
    include_source: bool = True
    group_by_source: bool = False
class ContextFormatter:
    """Formats retrieved chunks for LLM context with dynamic heading inclusion.

    This formatter chooses how many headings to show based on content size,
    to save tokens while keeping each chunk understandable on its own:

    - Small chunks (below ``small_chunk_threshold``): full heading path
    - Medium chunks (below ``medium_chunk_threshold``): last 2 heading levels
    - Large chunks: no headings (content is treated as self-contained)

    Example:
        ```python
        formatter = ContextFormatter(FormatterConfig(
            small_chunk_threshold=200,
            include_scores=True
        ))

        # Format standard results
        context = formatter.format(results)

        # Format merged chunks
        context = formatter.format_merged(merged_chunks)
        ```
    """

    def __init__(self, config: FormatterConfig | None = None):
        """Initialize the context formatter.

        Args:
            config: Formatter configuration, uses defaults if not provided
        """
        self.config = config or FormatterConfig()

    def format(self, results: list[dict[str, Any]]) -> str:
        """Format search results for LLM context.

        Args:
            results: Search results from RAGKnowledgeBase.query(). Each result
                is read through the optional keys ``text``, ``source``,
                ``similarity``, ``heading_path`` and ``metadata``.

        Returns:
            Formatted context string with chunks separated by ``---`` rules,
            or an empty string when ``results`` is empty.
        """
        if not results:
            return ""

        if self.config.group_by_source:
            return self._format_grouped_by_source(results)

        return "\n\n---\n\n".join(
            self._format_result(result, index)
            for index, result in enumerate(results, 1)
        )

    def format_merged(self, merged_chunks: list[MergedChunk]) -> str:
        """Format merged chunks for LLM context.

        Args:
            merged_chunks: Merged chunks from ChunkMerger

        Returns:
            Formatted context string, or an empty string when there are no
            chunks.
        """
        if not merged_chunks:
            return ""

        # Convert to the plain result-dict shape so both entry points share
        # a single formatting path.
        results = [
            {
                "text": chunk.text,
                "source": chunk.source,
                "heading_path": chunk.heading_display,
                "similarity": chunk.avg_similarity,
                "metadata": {
                    "headings": chunk.heading_path,
                    "content_length": chunk.content_length,
                },
            }
            for chunk in merged_chunks
        ]
        return self.format(results)

    def _format_result(
        self,
        result: dict[str, Any],
        index: int,
        *,
        include_source: bool | None = None,
    ) -> str:
        """Format a single result with dynamic heading inclusion.

        Args:
            result: Search result dictionary
            index: Result index for numbering
            include_source: Whether to append the ``(Source: ...)`` line.
                ``None`` (the default) defers to ``config.include_source``.

        Returns:
            Formatted chunk: a header line (``[index]``, optional score,
            optional headings), the stripped text, and an optional source
            line, joined by newlines.
        """
        # ``or`` / ``is None`` guard against keys that are present but None.
        text = result.get("text") or ""
        source = result.get("source") or ""
        similarity = result.get("similarity")
        if similarity is None:
            similarity = 0.0
        metadata = result.get("metadata") or {}

        headings = self._resolve_headings(result, metadata)

        # Prefer an explicit content length (e.g. from merged chunks) over
        # the length of the text actually being rendered.
        content_length = metadata.get("content_length")
        if content_length is None:
            content_length = len(text)

        display_headings = self._get_display_headings(headings, content_length)

        # Header: "[index]", then optional "[score]", then optional headings.
        header_parts = [f"[{index}]"]
        if self.config.include_scores:
            header_parts.append(f"[{similarity:.2f}]")
        if display_headings:
            header_parts.append(" > ".join(display_headings))

        lines = [" ".join(header_parts), text.strip()]

        if include_source is None:
            include_source = self.config.include_source
        if include_source and source:
            lines.append(f"(Source: {source})")

        return "\n".join(lines)

    @staticmethod
    def _resolve_headings(
        result: dict[str, Any], metadata: dict[str, Any]
    ) -> list[str]:
        """Return a result's heading path as a list of heading strings.

        ``metadata["headings"]`` takes precedence. Otherwise
        ``result["heading_path"]`` is used, either as a ``" > "``-joined
        string or as a list/tuple. Anything else yields no headings.
        """
        headings = metadata.get("headings") or result.get("heading_path")
        if isinstance(headings, str):
            return headings.split(" > ") if headings else []
        if isinstance(headings, (list, tuple)):
            return list(headings)
        return []

    def _get_display_headings(
        self,
        headings: list[str],
        content_length: int,
    ) -> list[str]:
        """Get headings to display based on content length.

        Implements dynamic heading inclusion:
        - Small chunks: Full heading path
        - Medium chunks: Last 2 heading levels
        - Large chunks: No headings

        Args:
            headings: Full heading path
            content_length: Length of content in characters

        Returns:
            New list of headings to display (possibly empty)
        """
        if not headings:
            return []

        if content_length < self.config.small_chunk_threshold:
            return list(headings)
        if content_length < self.config.medium_chunk_threshold:
            # Slicing already returns the whole path when it has <= 2 levels.
            return list(headings[-2:])
        # Large chunks: omit headings (content is self-contained)
        return []

    def _format_grouped_by_source(self, results: list[dict[str, Any]]) -> str:
        """Format results grouped by source file.

        Groups appear in order of each source's first occurrence. Chunk
        numbering runs continuously across groups. Results with a missing or
        empty source are grouped under ``"unknown"``.

        Args:
            results: Search results

        Returns:
            Formatted context string with one ``## Source:`` section per
            source, sections separated by ``---`` rules
        """
        # Plain dicts preserve insertion order, which fixes group order.
        groups: dict[str, list[dict[str, Any]]] = {}
        for result in results:
            source = result.get("source") or "unknown"
            groups.setdefault(source, []).append(result)

        formatted_groups = []
        chunk_index = 1
        for source, source_results in groups.items():
            group_lines = [f"## Source: {source}"]
            for result in source_results:
                # The group header already names the source. Suppress the
                # per-chunk source line when building the chunk rather than
                # filtering rendered lines, which would also drop content
                # lines that happen to start with "(Source:".
                group_lines.append(
                    self._format_result(result, chunk_index, include_source=False)
                )
                chunk_index += 1
            formatted_groups.append("\n\n".join(group_lines))

        return "\n\n---\n\n".join(formatted_groups)

    def wrap_for_prompt(self, context: str, tag: str = "knowledge_base") -> str:
        """Wrap formatted context in XML tags for prompt injection.

        Args:
            context: Formatted context string
            tag: Tag name to wrap with

        Returns:
            Context wrapped in ``<tag>...</tag>``, or an empty string when
            ``context`` is empty
        """
        if not context:
            return ""
        return f"<{tag}>\n{context}\n</{tag}>"