Coverage for src/pullapprove/config.py: 70%

231 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-11 10:20 -0500

1import re 

2import tomllib 

3from collections.abc import Generator 

4from enum import Enum 

5from pathlib import Path 

6from typing import TYPE_CHECKING 

7 

8from pydantic import ( 

9 BaseModel, 

10 ConfigDict, 

11 Field, 

12 RootModel, 

13 field_validator, 

14) 

15from wcmatch import glob 

16 

17if TYPE_CHECKING: 

18 from .pullrequests import PullRequest, User 

19 

# The canonical config file is CODEREVIEW.toml; "extends" entries are
# validated to point at files whose basename starts with this prefix
# (see ConfigModel.validate_extends).
CONFIG_FILENAME_PREFIX = "CODEREVIEW"
CONFIG_FILENAME = "CODEREVIEW.toml"

22 

23 

24def _expand_aliases(values: list[str], aliases: dict[str, list[str]]) -> list[str]: 

25 """Replace alias references in a list with their mapped values.""" 

26 expanded: list[str] = [] 

27 for value in values: 

28 if isinstance(value, str) and value.startswith("$"): 

29 expanded.extend(aliases.get(value[1:], [])) 

30 else: 

31 expanded.append(value) 

32 

33 # Remove duplicates while preserving order 

34 return list(dict.fromkeys(expanded)) 

35 

36 

class ReviewedForChoices(str, Enum):
    """Whether a scope's review must be (or must not be) "reviewed-for".

    EMPTY is the unset default.
    """

    EMPTY = ""
    REQUIRED = "required"
    IGNORED = "ignored"

41 

42 

class OwnershipChoices(str, Enum):
    """How a scope combines with other matching scopes.

    EMPTY is the unset default; APPEND and GLOBAL are rendered as "+" and
    "*" prefixes by ScopeModel.printed_name.
    """

    EMPTY = ""
    APPEND = "append"
    GLOBAL = "global"

47 

48 

49class ScopeModel(BaseModel): 

50 model_config = ConfigDict(extra="forbid") 

51 

52 # Required fields 

53 name: str = Field(min_length=1) 

54 paths: list[str] = Field(min_length=1) 

55 

56 # Optional fields 

57 

58 # Expanded version of lines could be dict 

59 # with fnmatch, regex, exclude patterns, etc? 

60 code: list[str] = [] 

61 

62 # This only filtering field that can't be used with raw diff/files... 

63 # If we get into that, the others are: 

64 # - labels 

65 # - ref (have branches at the root level...) 

66 # - statuses 

67 # - dates 

68 # - body 

69 # - title 

70 # - other scopes 

71 # (this is how I ended up with expressions... 

72 # I'm not trying to build a general purpose workflow tool, 

73 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy) 

74 authors: list[str] = [] 

75 

76 # (defaults should be the "empty" values) 

77 description: str = "" 

78 reviewers: list[str] = [] 

79 alternates: list[str] = [] 

80 cc: list[str] = [] 

81 

82 # Review scoring 

83 require: int = 0 

84 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY 

85 author_value: int = 0 

86 git_author_value: int = 0 

87 git_coauthor_value: int = 0 

88 

89 # How scopes are combined 

90 ownership: OwnershipChoices = OwnershipChoices.EMPTY 

91 

92 # Actionable items 

93 request: int = 0 

94 labels: list[str] = [] 

95 instructions: str = "" 

96 

97 def printed_name(self): 

98 match self.ownership: 

99 case OwnershipChoices.APPEND: 

100 return "+" + self.name 

101 case OwnershipChoices.GLOBAL: 

102 return "*" + self.name 

103 

104 return self.name 

105 

106 def __eq__(self, other): 

107 return self.name == other.name 

108 

109 def matches_path(self, path: Path): 

110 # TODO paths shouldn't start with / 

111 return glob.globmatch( 

112 path, 

113 self.paths, 

114 flags=glob.GLOBSTAR 

115 | glob.BRACE 

116 | glob.NEGATE 

117 | glob.IGNORECASE 

118 | glob.DOTGLOB, 

119 ) 

120 

121 def matches_code(self, code: str) -> Generator[dict[str, int], None, None]: 

122 patterns = getattr(self, "_line_regex_patterns", []) 

123 if not patterns: 

124 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code] 

125 self._code_regex_patterns = patterns # Cache the compiled patterns 

126 

127 for pattern in patterns: 

128 for match in pattern.finditer(code): 

129 start_index = match.start() 

130 end_index = match.end() 

131 

132 start_line = code.count("\n", 0, start_index) + 1 

133 start_col = start_index - code.rfind("\n", 0, start_index) 

134 

135 end_line = code.count("\n", 0, end_index) + 1 

136 end_col = end_index - code.rfind("\n", 0, end_index) 

137 

138 yield { 

139 "start_line": start_line, 

140 "start_col": start_col, 

141 "end_line": end_line, 

142 "end_col": end_col, 

143 } 

144 

145 def matches_author(self, author: "User") -> bool: 

146 if not self.authors: 

147 # No authors specified, so assume it matches 

148 return True 

149 

150 negated_authors = [a[1:] for a in self.authors if a.startswith("!")] 

151 authors = [a for a in self.authors if not a.startswith("!")] 

152 

153 if author.username in negated_authors: 

154 # If the author is in the negated list, return False 

155 return False 

156 

157 if author.username in authors: 

158 # If the author is in the authors list, return True 

159 return True 

160 

161 return False 

162 

163 def enabled_for_pullrequest(self, pullrequest: "PullRequest") -> bool: 

164 # Paths/code are matched during diff parsing, 

165 # but we also consider authors in the context of a pull request so do that here. 

166 return self.matches_author(pullrequest.author) 

167 

168 

class LargeScaleChangeModel(BaseModel):
    """Review settings for a large-scale change (LSC) pull request.

    Note: an LSC only applies to diffs, not raw files, because we have to
    know what *changed*. It is similar to a scope but more manual: there
    has to be at least one reviewer, so if no LSC config is defined an LSC
    PR errors until one is added.
    """

    model_config = ConfigDict(extra="forbid")

    # Number of approvals needed to satisfy the LSC review.
    require: int = 1
    reviewers: list[str] = []  # Field(min_length=1)
    # min_paths: int = 300
    # min_lines: int = 3000
    labels: list[str] = []
    # really need author value too...?

182 # really need author value too...? 

183 

184 

class ConfigModel(BaseModel):
    """One CODEREVIEW.toml file: branch filters, aliases, scopes, and
    optional large-scale-change settings, plus compilation of `extends`."""

    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes):
        """Reject configs where two scopes share a name."""
        seen = set()
        for scope in scopes:
            if scope.name in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(scope.name)

        return scopes

    @field_validator("extends", mode="before")
    @classmethod
    def validate_extends(cls, extends):
        """Require every `extends` entry to name a CODEREVIEW* file."""
        for path in extends:
            basename = Path(path).name
            if not basename.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self, config_path: Path, other_configs: "ConfigModels"
    ) -> "ConfigModel":
        """
        Merge extends, replace aliases.

        The result is cached on the instance, so repeated calls are cheap.
        Raises ValueError when an `extends` path is not in *other_configs*.
        """

        if getattr(self, "_compiled_config", None) is not None:
            return self._compiled_config

        # Create a copy of the data from what we have currently
        compiled_data = self.model_dump()

        for extend_path in self.extends:
            if extend_path not in other_configs:
                raise ValueError(f"Config {extend_path} not found")

            # Fix: compile the extended config and merge from the *compiled*
            # result — previously the return value was discarded and the
            # uncompiled model was dumped, so transitive extends (a config
            # extended by the extended config) were never merged in.
            extended_config = other_configs[extend_path].compiled_config(
                Path(extend_path), other_configs
            )

            # pydantic documents `include` as a set of field names.
            extended_config_dumped = extended_config.model_dump(
                include={"branches", "aliases", "scopes", "large_scale_change"}
            )

            # Extended scopes come first so this config's own scopes follow.
            compiled_data["scopes"] = (
                extended_config_dumped["scopes"] + compiled_data["scopes"]
            )
            # A locally-defined large_scale_change wins over the extended one.
            compiled_data["large_scale_change"] = (
                compiled_data["large_scale_change"]
                or extended_config_dumped["large_scale_change"]
            )
            # Dict union: local aliases override extended ones on key clashes.
            compiled_data["aliases"] = (
                extended_config_dumped["aliases"] | compiled_data["aliases"]
            )
            compiled_data["branches"] = (
                extended_config_dumped["branches"] + compiled_data["branches"]
            )

        # Root aliases
        for field in ["extends", "branches"]:
            if field in compiled_data:
                compiled_data[field] = _expand_aliases(
                    compiled_data[field], compiled_data["aliases"]
                )

        # Expand aliases for any aliasable list fields
        for scope in compiled_data["scopes"]:
            for field in [
                "paths",
                "code",
                "authors",
                "reviewers",
                "alternates",
                "cc",
                "labels",
            ]:
                if field in scope:
                    scope[field] = _expand_aliases(
                        scope[field], compiled_data["aliases"]
                    )

        if large_scale_change := compiled_data.get("large_scale_change"):
            for field in ["reviewers", "labels"]:
                large_scale_change[field] = _expand_aliases(
                    large_scale_change[field],
                    compiled_data["aliases"],
                )

        # Create a new config from the merged data
        self._compiled_config = ConfigModel.from_data(
            data=compiled_data,
            path=config_path,
        )

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str):
        """Parse a TOML config file from disk."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str):
        """Parse a config from an in-memory TOML string."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict, path: Path | str):
        """Build a config from already-parsed TOML data.

        `path` is currently unused but kept for a stable construction API.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return True if the base/head branch pair matches any `branches`
        pattern ("base..head" or "base...head"); no patterns matches all."""
        if not self.branches:
            # No branches specified, so assume it matches
            return True

        for pattern in self.branches:
            splitter = "..." if "..." in pattern else ".."
            parts = pattern.split(splitter)
            base_pattern = parts[0]
            head_pattern = parts[1] if len(parts) > 1 else None

            # An empty side of the pattern matches any branch.
            base_match = (
                glob.globmatch(base_branch, base_pattern) if base_pattern else True
            )
            head_match = (
                glob.globmatch(head_branch, head_pattern) if head_pattern else True
            )

            if base_match and head_match:
                return True

        return False

    def enabled_for_pullrequest(self, pullrequest: "PullRequest") -> bool:
        """Return True if this config applies to *pullrequest*'s branches."""
        return self.matches_branches(
            pullrequest.base_branch,
            pullrequest.head_branch,
        )

362 

363 

class ConfigModels(RootModel):
    """A collection of ConfigModel instances, keyed by config file path."""

    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data):
        """Load configs from a dict of {path: parsed TOML data}."""
        configs = cls(root={})

        for path, config_data in data.items():
            config = ConfigModel.from_data(config_data, Path(path))
            configs.add_config(config, Path(path))

        return configs

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]):
        """Load configs from a dict of already-built models."""
        configs = cls(root={})

        for path, config_model in models.items():
            configs.add_config(config_model, Path(path))

        return configs

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> model mapping."""
        return dict(self.root.items())

    def add_config(self, config: ConfigModel, path: Path):
        """Register *config* under the string form of *path*."""
        self.root[str(path)] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the LSC settings from the repo-root CODEREVIEW.toml,
        falling back to default settings.

        Fix: the return annotation previously claimed ConfigModel, but this
        method always returns a LargeScaleChangeModel.
        """
        primary_config = CONFIG_FILENAME

        if primary_config in self.root:
            if lsc := self.root[primary_config].large_scale_change:
                return lsc

        return LargeScaleChangeModel()

    def __bool__(self):
        return bool(self.root)

    def __getitem__(self, key: str):
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Compile and return the nearest non-template config above
        *file_path*, walking up its parent directories.

        Raises ValueError when no applicable config exists.
        """
        for parent in file_path.parents:
            parent_config_path = str(parent / CONFIG_FILENAME)

            if parent_config_path in self.root:
                config = self.root[parent_config_path]

                if config.template:
                    # Skip templates
                    continue

                return config.compiled_config(Path(parent_config_path), self.root)

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[ConfigModel, None, None]:
        """Yield the compiled form of every non-template config.

        Fix: the annotation previously promised (path, config) tuples, but
        only the compiled ConfigModel has ever been yielded; the annotation
        is corrected rather than the runtime behavior callers rely on.
        """
        for config_path, config in self.root.items():
            if config.template:
                # Skip templates
                continue

            yield config.compiled_config(config_path, self.root)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.
        """
        return sum(len(config.scopes) for config in self.root.values())

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewer entries across all scopes.
        """
        return sum(
            len(scope.reviewers)
            for config in self.root.values()
            for scope in config.scopes
        )

    def filter_for_pullrequest(
        self,
        pullrequest: "PullRequest",
    ) -> "ConfigModels":
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.
        """
        filtered_configs = {}

        for config_path, config in self.root.items():
            compiled_config = config.compiled_config(config_path, self.root)

            if not compiled_config.enabled_for_pullrequest(pullrequest):
                # Remove the config from the list
                continue

            # Collect the names of scopes to remove
            scopes_to_remove = set()
            for scope in compiled_config.scopes:
                if not scope.enabled_for_pullrequest(pullrequest):
                    scopes_to_remove.add(scope.name)

            # Create a filtered copy of the config data
            filtered_config_data = config.model_dump()

            # Filter scopes in a simple, declarative way
            if "scopes" in filtered_config_data:
                filtered_config_data["scopes"] = [
                    scope
                    for scope in filtered_config_data["scopes"]
                    if scope["name"] not in scopes_to_remove
                ]

            # Rebuild using the original/modified raw data
            filtered_configs[config_path] = filtered_config_data

        return ConfigModels.from_configs_data(filtered_configs)