Coverage for src / dataknobs_bots / knowledge / query / transformer.py: 27%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-16 10:13 -0700

1"""Query transformation using LLM for improved retrieval. 

2 

3This module provides LLM-based query transformation to generate 

4optimized search queries from user input. 

5""" 

6 

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

11 

12 

@dataclass
class TransformerConfig:
    """Settings that control LLM-based query transformation.

    Attributes:
        enabled: Whether transformation is performed at all; when False,
            transformers pass user input through unchanged.
        llm_provider: Name of the LLM provider (e.g., "ollama", "openai").
        llm_model: Model used to generate the transformed queries.
        num_queries: How many alternative queries to produce per input.
        domain_context: Optional domain hint used to sharpen the queries.
    """

    enabled: bool = False
    llm_provider: str = "ollama"
    llm_model: str = "llama3.2"
    num_queries: int = 3
    domain_context: str = ""

30 

31 

class QueryTransformer:
    """LLM-based query transformation for improved RAG retrieval.

    Transforms user input into optimized search queries by using an LLM
    to extract key concepts and generate alternative phrasings.

    This is particularly useful when:
    - User input contains literal text to analyze (not queries)
    - User asks vague questions that need expansion
    - Domain-specific terminology needs translation

    Example:
        ```python
        config = TransformerConfig(
            enabled=True,
            llm_provider="ollama",
            llm_model="llama3.2",
            domain_context="prompt engineering"
        )
        transformer = QueryTransformer(config)
        await transformer.initialize()

        # Transform user input to search queries
        queries = await transformer.transform(
            "Analyze this: Write a poem about cats"
        )
        # Returns: ["prompt analysis techniques", "evaluating prompt quality", ...]
        ```
    """

    # Matches numbering ("1.", "2)", "3:", "4-") or bullet ("-", "*", "•")
    # prefixes that LLMs often emit despite being told not to.  A separator
    # after the digits is required so queries that legitimately start with a
    # number (e.g. "3D printing tips") are left intact — a character-set
    # lstrip would eat their leading characters.
    _LINE_PREFIX = re.compile(r"^(?:\d+\s*[.):-]\s*|[-*\u2022]\s+)")

    def __init__(self, config: TransformerConfig | None = None):
        """Initialize the query transformer.

        Args:
            config: Transformer configuration, uses defaults if not provided
        """
        self.config = config or TransformerConfig()
        self._llm: Any = None  # provider instance, created in initialize()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the LLM provider.

        Must be called before using transform() if enabled.  A no-op when
        transformation is disabled in the config.
        """
        if not self.config.enabled:
            return

        # Imported lazily so a disabled transformer never requires the
        # dataknobs_llm dependency to be importable.
        from dataknobs_llm.llm import LLMProviderFactory

        factory = LLMProviderFactory(is_async=True)
        self._llm = factory.create({
            "provider": self.config.llm_provider,
            "model": self.config.llm_model,
        })
        await self._llm.initialize()
        self._initialized = True

    async def close(self) -> None:
        """Close the LLM provider and release resources.

        Safe to call multiple times; subsequent calls are no-ops.
        """
        if self._llm and hasattr(self._llm, "close"):
            await self._llm.close()
        # Drop the provider reference so a closed client is never reused
        # and a repeated close() does not attempt to close it again.
        self._llm = None
        self._initialized = False

    async def transform(
        self,
        user_input: str,
        num_queries: int | None = None,
    ) -> list[str]:
        """Transform user input into optimized search queries.

        Args:
            user_input: The user's message or question
            num_queries: Number of queries to generate (overrides config)

        Returns:
            List of optimized search queries

        Raises:
            RuntimeError: If transformer is enabled but not initialized
        """
        # If disabled, pass the original input through as the single query.
        if not self.config.enabled:
            return [user_input]

        self._ensure_initialized()

        num = num_queries or self.config.num_queries
        prompt = self._build_prompt(user_input, num)
        return await self._generate_queries(prompt, user_input, num)

    async def transform_with_context(
        self,
        user_input: str,
        conversation_context: str,
        num_queries: int | None = None,
    ) -> list[str]:
        """Transform with additional conversation context.

        Args:
            user_input: The user's message
            conversation_context: Recent conversation history
            num_queries: Number of queries to generate

        Returns:
            List of optimized search queries

        Raises:
            RuntimeError: If transformer is enabled but not initialized
        """
        if not self.config.enabled:
            return [user_input]

        self._ensure_initialized()

        num = num_queries or self.config.num_queries
        prompt = self._build_contextual_prompt(
            user_input, conversation_context, num
        )
        return await self._generate_queries(prompt, user_input, num)

    def _ensure_initialized(self) -> None:
        """Raise RuntimeError if initialize() has not completed."""
        if not self._initialized:
            raise RuntimeError(
                "QueryTransformer not initialized. Call initialize() first."
            )

    async def _generate_queries(
        self, prompt: str, fallback: str, num: int
    ) -> list[str]:
        """Run the LLM on ``prompt`` and return at most ``num`` parsed queries.

        Shared tail of transform() and transform_with_context().
        """
        response = await self._llm.generate(prompt)
        return self._parse_response(response, fallback)[:num]

    def _build_prompt(self, user_input: str, num_queries: int) -> str:
        """Build the transformation prompt.

        Args:
            user_input: User's message
            num_queries: Number of queries to generate

        Returns:
            Prompt string for LLM
        """
        domain_context = ""
        if self.config.domain_context:
            domain_context = f" in the context of {self.config.domain_context}"

        return f"""Generate {num_queries} search queries to find relevant knowledge base content for the following user message{domain_context}.

User message: "{user_input}"

Focus on:
- Key concepts and techniques being discussed
- The underlying intent, not the literal text
- Related topics that would provide useful context

Return ONLY the search queries, one per line, without numbering or explanation.
Keep each query concise (2-6 words).
"""

    def _parse_response(self, response: str, fallback: str) -> list[str]:
        """Parse LLM response into list of queries.

        Args:
            response: Raw LLM response
            fallback: Fallback query if parsing fails

        Returns:
            List of parsed queries (never empty; falls back to ``fallback``)
        """
        queries: list[str] = []
        for raw_line in response.strip().split("\n"):
            # Remove numbering/bullet prefixes, then surrounding quotes.
            cleaned = self._LINE_PREFIX.sub("", raw_line.strip())
            cleaned = cleaned.strip("\"'")

            # Drop blank lines and fragments too short to be useful queries.
            if len(cleaned) > 2:
                queries.append(cleaned)

        # Guarantee at least one query so retrieval always has input.
        return queries or [fallback]

    def _build_contextual_prompt(
        self,
        user_input: str,
        conversation_context: str,
        num_queries: int,
    ) -> str:
        """Build prompt with conversation context.

        Args:
            user_input: User's message
            conversation_context: Recent conversation
            num_queries: Number of queries to generate

        Returns:
            Prompt string for LLM
        """
        domain_context = ""
        if self.config.domain_context:
            domain_context = f" in the context of {self.config.domain_context}"

        return f"""Generate {num_queries} search queries to find relevant knowledge base content for the user's message{domain_context}.

Recent conversation context:
{conversation_context}

Current user message: "{user_input}"

Focus on:
- Key concepts relevant to what the user is asking
- Context from the conversation that clarifies the query
- Related topics that would provide useful information

Return ONLY the search queries, one per line, without numbering or explanation.
Keep each query concise (2-6 words).
"""

262 

263 

async def create_transformer(config: dict[str, Any]) -> QueryTransformer:
    """Create and initialize a QueryTransformer from config dict.

    Convenience function for creating transformer from configuration.

    Args:
        config: Configuration dictionary with TransformerConfig fields

    Returns:
        Initialized QueryTransformer

    Example:
        ```python
        transformer = await create_transformer({
            "enabled": True,
            "llm_provider": "ollama",
            "llm_model": "llama3.2",
            "domain_context": "prompt engineering"
        })
        ```
    """
    # Expand the dict into a typed config, then build and start the
    # transformer in one step.
    transformer = QueryTransformer(TransformerConfig(**config))
    await transformer.initialize()
    return transformer