Coverage for src/pullapprove/config.py: 70%
231 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-11 10:20 -0500
1import re
2import tomllib
3from collections.abc import Generator
4from enum import Enum
5from pathlib import Path
6from typing import TYPE_CHECKING
8from pydantic import (
9 BaseModel,
10 ConfigDict,
11 Field,
12 RootModel,
13 field_validator,
14)
15from wcmatch import glob
17if TYPE_CHECKING:
18 from .pullrequests import PullRequest, User
20CONFIG_FILENAME_PREFIX = "CODEREVIEW"
21CONFIG_FILENAME = "CODEREVIEW.toml"
24def _expand_aliases(values: list[str], aliases: dict[str, list[str]]) -> list[str]:
25 """Replace alias references in a list with their mapped values."""
26 expanded: list[str] = []
27 for value in values:
28 if isinstance(value, str) and value.startswith("$"):
29 expanded.extend(aliases.get(value[1:], []))
30 else:
31 expanded.append(value)
33 # Remove duplicates while preserving order
34 return list(dict.fromkeys(expanded))
class ReviewedForChoices(str, Enum):
    """Whether a scope's "reviewed-for" state is enforced when scoring reviews."""

    EMPTY = ""  # default: no explicit setting
    REQUIRED = "required"
    IGNORED = "ignored"
class OwnershipChoices(str, Enum):
    """How a scope combines with other matching scopes (see ScopeModel.printed_name)."""

    EMPTY = ""  # default: no explicit setting
    APPEND = "append"  # printed with a "+" prefix
    GLOBAL = "global"  # printed with a "*" prefix
49class ScopeModel(BaseModel):
50 model_config = ConfigDict(extra="forbid")
52 # Required fields
53 name: str = Field(min_length=1)
54 paths: list[str] = Field(min_length=1)
56 # Optional fields
58 # Expanded version of lines could be dict
59 # with fnmatch, regex, exclude patterns, etc?
60 code: list[str] = []
62 # This only filtering field that can't be used with raw diff/files...
63 # If we get into that, the others are:
64 # - labels
65 # - ref (have branches at the root level...)
66 # - statuses
67 # - dates
68 # - body
69 # - title
70 # - other scopes
71 # (this is how I ended up with expressions...
72 # I'm not trying to build a general purpose workflow tool,
73 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy)
74 authors: list[str] = []
76 # (defaults should be the "empty" values)
77 description: str = ""
78 reviewers: list[str] = []
79 alternates: list[str] = []
80 cc: list[str] = []
82 # Review scoring
83 require: int = 0
84 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY
85 author_value: int = 0
86 git_author_value: int = 0
87 git_coauthor_value: int = 0
89 # How scopes are combined
90 ownership: OwnershipChoices = OwnershipChoices.EMPTY
92 # Actionable items
93 request: int = 0
94 labels: list[str] = []
95 instructions: str = ""
97 def printed_name(self):
98 match self.ownership:
99 case OwnershipChoices.APPEND:
100 return "+" + self.name
101 case OwnershipChoices.GLOBAL:
102 return "*" + self.name
104 return self.name
106 def __eq__(self, other):
107 return self.name == other.name
109 def matches_path(self, path: Path):
110 # TODO paths shouldn't start with /
111 return glob.globmatch(
112 path,
113 self.paths,
114 flags=glob.GLOBSTAR
115 | glob.BRACE
116 | glob.NEGATE
117 | glob.IGNORECASE
118 | glob.DOTGLOB,
119 )
121 def matches_code(self, code: str) -> Generator[dict[str, int], None, None]:
122 patterns = getattr(self, "_line_regex_patterns", [])
123 if not patterns:
124 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code]
125 self._code_regex_patterns = patterns # Cache the compiled patterns
127 for pattern in patterns:
128 for match in pattern.finditer(code):
129 start_index = match.start()
130 end_index = match.end()
132 start_line = code.count("\n", 0, start_index) + 1
133 start_col = start_index - code.rfind("\n", 0, start_index)
135 end_line = code.count("\n", 0, end_index) + 1
136 end_col = end_index - code.rfind("\n", 0, end_index)
138 yield {
139 "start_line": start_line,
140 "start_col": start_col,
141 "end_line": end_line,
142 "end_col": end_col,
143 }
145 def matches_author(self, author: "User") -> bool:
146 if not self.authors:
147 # No authors specified, so assume it matches
148 return True
150 negated_authors = [a[1:] for a in self.authors if a.startswith("!")]
151 authors = [a for a in self.authors if not a.startswith("!")]
153 if author.username in negated_authors:
154 # If the author is in the negated list, return False
155 return False
157 if author.username in authors:
158 # If the author is in the authors list, return True
159 return True
161 return False
163 def enabled_for_pullrequest(self, pullrequest: "PullRequest") -> bool:
164 # Paths/code are matched during diff parsing,
165 # but we also consider authors in the context of a pull request so do that here.
166 return self.matches_author(pullrequest.author)
class LargeScaleChangeModel(BaseModel):
    """Settings for reviewing a large-scale change (LSC) pull request."""

    model_config = ConfigDict(extra="forbid")

    # Note, an LSC only applies to diffs, not raw files,
    # because we have to know what *changed*.

    # Pretty similar to a scope, but more manual.
    # There has to be at least one reviewer. So if a LSC config is not defined, an LSC PR error until you add one.
    require: int = 1  # number of approvals an LSC PR needs
    reviewers: list[str] = []  # Field(min_length=1)
    # min_paths: int = 300
    # min_lines: int = 3000
    labels: list[str] = []
    # really need author value too...?
class ConfigModel(BaseModel):
    """A single CODEREVIEW.toml config: extends, branch filters, aliases,
    an optional large-scale-change policy, and a list of scopes."""

    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes):
        """Reject configs where two scopes share a name (names are the scope identity)."""
        seen = set()
        for scope in scopes:
            if scope.name in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(scope.name)

        return scopes

    @field_validator("extends", mode="before")
    @classmethod
    def validate_extends(cls, extends):
        """Ensure every extends entry points at a CODEREVIEW* file."""
        for path in extends:
            basename = Path(path).name
            if not basename.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self, config_path: Path, other_configs: "ConfigModels"
    ) -> "ConfigModel":
        """
        Merge extends, replace aliases.

        Raises ValueError when an extends path is not in `other_configs`.
        Compilation is idempotent, so the result is cached per instance.
        """
        if getattr(self, "_compiled_config", None) is not None:
            return self._compiled_config

        # Create a copy of the data from what we have currently
        compiled_data = self.model_dump()

        for extend_path in self.extends:
            if extend_path not in other_configs:
                raise ValueError(f"Config {extend_path} not found")

            extended_config = other_configs[extend_path]

            # FIX: compile the parent and dump the *compiled* result.
            # The original discarded the compiled config and dumped the raw
            # model, silently dropping anything inherited transitively
            # (a grandparent's scopes, aliases, branches, or LSC settings).
            extended_compiled = extended_config.compiled_config(
                Path(extend_path), other_configs
            )

            # FIX: pydantic v2's `include` expects a set (or dict), not a list.
            extended_config_dumped = extended_compiled.model_dump(
                include={"branches", "aliases", "scopes", "large_scale_change"}
            )

            # Parent scopes/branches come first; the child's aliases and LSC
            # settings win when both define them.
            compiled_data["scopes"] = (
                extended_config_dumped["scopes"] + compiled_data["scopes"]
            )
            compiled_data["large_scale_change"] = (
                compiled_data["large_scale_change"]
                or extended_config_dumped["large_scale_change"]
            )
            compiled_data["aliases"] = (
                extended_config_dumped["aliases"] | compiled_data["aliases"]
            )
            compiled_data["branches"] = (
                extended_config_dumped["branches"] + compiled_data["branches"]
            )

        # Root aliases
        for field in ["extends", "branches"]:
            if field in compiled_data:
                compiled_data[field] = _expand_aliases(
                    compiled_data[field], compiled_data["aliases"]
                )

        # Expand aliases for any aliasable list fields
        for scope in compiled_data["scopes"]:
            for field in [
                "paths",
                "code",
                "authors",
                "reviewers",
                "alternates",
                "cc",
                "labels",
            ]:
                if field in scope:
                    scope[field] = _expand_aliases(
                        scope[field], compiled_data["aliases"]
                    )

        if large_scale_change := compiled_data.get("large_scale_change"):
            for field in ["reviewers", "labels"]:
                large_scale_change[field] = _expand_aliases(
                    large_scale_change[field],
                    compiled_data["aliases"],
                )

        # Create a new config from the merged data
        self._compiled_config = ConfigModel.from_data(
            data=compiled_data,
            path=config_path,
        )

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str):
        """Load and parse a TOML config file from disk."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str):
        """Parse a config from an in-memory TOML string."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict, path: Path | str):
        """Build a config from already-parsed data.

        NOTE(review): `path` is currently unused; it is kept for interface
        compatibility with callers that pass it.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return True if any "base..head" (or "base...head") pattern matches.

        Either side of a pattern may be empty, which matches any branch.
        """
        if not self.branches:
            # No branches specified, so assume it matches
            return True

        for pattern in self.branches:
            splitter = "..." if "..." in pattern else ".."
            parts = pattern.split(splitter)
            base_pattern = parts[0]
            head_pattern = parts[1] if len(parts) > 1 else None

            base_match = (
                glob.globmatch(base_branch, base_pattern) if base_pattern else True
            )
            head_match = (
                glob.globmatch(head_branch, head_pattern) if head_pattern else True
            )

            if base_match and head_match:
                return True

        return False

    def enabled_for_pullrequest(self, pullrequest: "PullRequest") -> bool:
        """Return True if this config applies to the pull request's branches."""
        return self.matches_branches(
            pullrequest.base_branch,
            pullrequest.head_branch,
        )

    # Kinda want the original toml if you dump? with comments etc
class ConfigModels(RootModel):
    """A mapping of config file path -> ConfigModel for a repository."""

    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data):
        """Load configs from a dict of {path: parsed TOML data}."""
        configs = cls(root={})

        for path, config_data in data.items():
            config = ConfigModel.from_data(config_data, Path(path))
            configs.add_config(config, Path(path))

        return configs

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]):
        """Load configs from a dict of already-built models."""
        configs = cls(root={})

        for path, config_model in models.items():
            configs.add_config(config_model, Path(path))

        return configs

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> model mapping."""
        return dict(self.root.items())

    def add_config(self, config: ConfigModel, path: Path):
        """Register `config` under the string form of `path`."""
        self.root[str(path)] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the LSC settings from the root config, or defaults if unset.

        FIX: the return annotation previously claimed ConfigModel, but this
        has always returned a LargeScaleChangeModel.
        """
        primary_config = CONFIG_FILENAME

        if primary_config in self.root:
            if lsc := self.root[primary_config].large_scale_change:
                return lsc

        return LargeScaleChangeModel()

    def __bool__(self):
        return bool(self.root)

    def __getitem__(self, key: str):
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Find and compile the nearest non-template config above `file_path`.

        Raises ValueError when no ancestor directory has a config.
        """
        for parent in file_path.parents:
            parent_config_path = str(parent / CONFIG_FILENAME)

            if parent_config_path in self.root:
                config = self.root[parent_config_path]

                if config.template:
                    # Skip templates
                    continue

                # Pass Path(...) and `self` to match compiled_config's declared
                # signature; lookup behavior is identical (ConfigModels
                # supports `in` and `[]`).
                return config.compiled_config(Path(parent_config_path), self)

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[ConfigModel, None, None]:
        """Yield the compiled form of every non-template config.

        FIX: the previous annotation promised (path, config) tuples, but only
        the compiled config has ever been yielded; the annotation now matches
        the actual behavior so callers are unaffected.
        """
        for config_path, config in self.root.items():
            if config.template:
                # Skip templates
                continue

            yield config.compiled_config(Path(config_path), self)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.
        """
        return sum(len(config.scopes) for config in self.root.values())

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewers across all configs.
        """
        return sum(
            len(scope.reviewers)
            for config in self.root.values()
            for scope in config.scopes
        )

    def filter_for_pullrequest(
        self,
        pullrequest: "PullRequest",
    ) -> "ConfigModels":
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.
        """
        filtered_configs = {}

        for config_path, config in self.root.items():
            compiled_config = config.compiled_config(Path(config_path), self)

            if not compiled_config.enabled_for_pullrequest(pullrequest):
                # Remove the config from the list
                continue

            # Collect the names of scopes to remove
            scopes_to_remove = set()
            for scope in compiled_config.scopes:
                if not scope.enabled_for_pullrequest(pullrequest):
                    scopes_to_remove.add(scope.name)

            # Create a filtered copy of the config data
            filtered_config_data = config.model_dump()

            # Filter scopes in a simple, declarative way
            if "scopes" in filtered_config_data:
                filtered_config_data["scopes"] = [
                    scope
                    for scope in filtered_config_data["scopes"]
                    if scope["name"] not in scopes_to_remove
                ]

            # Rebuild using the original/modified raw data
            filtered_configs[config_path] = filtered_config_data

        return ConfigModels.from_configs_data(filtered_configs)