Skip to content

Commit 6d93e65

Browse files
committed
fix: resolve CI failures and address CodeRabbit review feedback
- Fix pre-commit issues: formatting, import ordering, trailing whitespace - Convert webhook handlers to class-based pattern for consistency - Fix fail-open bug: TestCoverageCondition now returns violation on invalid regex - Fix inconsistent extension filtering between evaluate() and validate() - Fix max_hours=0 edge case in CommentResponseTimeCondition (falsy check) - Refactor DiffPattern/SecurityPattern into shared _PatchPatternCondition base - Remove redundant validate() overrides that duplicate BaseCondition logic
1 parent 981a2f2 commit 6d93e65

11 files changed

Lines changed: 162 additions & 194 deletions

src/main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from src.webhooks.handlers.deployment_status import DeploymentStatusEventHandler
2626
from src.webhooks.handlers.issue_comment import IssueCommentEventHandler
2727
from src.webhooks.handlers.pull_request import PullRequestEventHandler
28-
from src.webhooks.handlers.pull_request_review import handle_pull_request_review
29-
from src.webhooks.handlers.pull_request_review_thread import handle_pull_request_review_thread
28+
from src.webhooks.handlers.pull_request_review import PullRequestReviewEventHandler
29+
from src.webhooks.handlers.pull_request_review_thread import PullRequestReviewThreadEventHandler
3030
from src.webhooks.handlers.push import PushEventHandler
3131
from src.webhooks.router import router as webhook_router
3232

@@ -75,9 +75,12 @@ async def lifespan(_app: FastAPI) -> Any:
7575
deployment_review_handler = DeploymentReviewEventHandler()
7676
deployment_protection_rule_handler = DeploymentProtectionRuleEventHandler()
7777

78+
pull_request_review_handler = PullRequestReviewEventHandler()
79+
pull_request_review_thread_handler = PullRequestReviewThreadEventHandler()
80+
7881
dispatcher.register_handler(EventType.PULL_REQUEST, pull_request_handler.handle)
79-
dispatcher.register_handler(EventType.PULL_REQUEST_REVIEW, handle_pull_request_review)
80-
dispatcher.register_handler(EventType.PULL_REQUEST_REVIEW_THREAD, handle_pull_request_review_thread)
82+
dispatcher.register_handler(EventType.PULL_REQUEST_REVIEW, pull_request_review_handler.handle)
83+
dispatcher.register_handler(EventType.PULL_REQUEST_REVIEW_THREAD, pull_request_review_thread_handler.handle)
8184
dispatcher.register_handler(EventType.PUSH, push_handler.handle)
8285
dispatcher.register_handler(EventType.CHECK_RUN, check_run_handler.handle)
8386
dispatcher.register_handler(EventType.ISSUE_COMMENT, issue_comment_handler.handle)

src/rules/conditions/access_control_advanced.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
1112
class NoSelfApprovalCondition(BaseCondition):
1213
"""Validates that a PR author cannot approve their own PR."""
1314

@@ -34,7 +35,7 @@ async def evaluate(self, context: Any) -> list[Violation]:
3435
for review in reviews:
3536
review_state = review.get("state") if isinstance(review, dict) else getattr(review, "state", None)
3637
reviewer = review.get("author") if isinstance(review, dict) else getattr(review, "author", None)
37-
38+
3839
if review_state == "APPROVED" and reviewer == author:
3940
self_approved = True
4041
break
@@ -51,10 +52,6 @@ async def evaluate(self, context: Any) -> list[Violation]:
5152

5253
return []
5354

54-
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
55-
violations = await self.evaluate({"parameters": parameters, "event": event})
56-
return len(violations) == 0
57-
5855

5956
class CrossTeamApprovalCondition(BaseCondition):
6057
"""Validates that a PR has approvals from specific teams."""
@@ -74,19 +71,19 @@ async def evaluate(self, context: Any) -> list[Violation]:
7471
return []
7572

7673
reviews = event.get("reviews", [])
77-
74+
7875
# In a real implementation, we would map reviewers to their GitHub Teams
7976
# For now, we simulate this by checking if the required teams are in the requested_teams list
8077
# and if we have enough total approvals. A robust implementation would need a GraphQL call
8178
# to fetch user team memberships.
82-
79+
8380
pr_details = event.get("pull_request_details", {})
8481
requested_teams = pr_details.get("requested_teams", [])
8582
requested_team_slugs = [t.get("slug") for t in requested_teams if t.get("slug")]
86-
83+
8784
missing_teams = []
8885
for req_team in required_teams:
89-
clean_team = req_team.replace("@", "").split("/")[-1] # Clean org/team to just team
86+
clean_team = req_team.replace("@", "").split("/")[-1] # Clean org/team to just team
9087
if clean_team in requested_team_slugs:
9188
# Team was requested, now check if anyone approved (simplified check)
9289
has_approval = any(
@@ -110,7 +107,3 @@ async def evaluate(self, context: Any) -> list[Violation]:
110107
]
111108

112109
return []
113-
114-
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
115-
violations = await self.evaluate({"parameters": parameters, "event": event})
116-
return len(violations) == 0

src/rules/conditions/compliance.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
1112
class SignedCommitsCondition(BaseCondition):
1213
"""Validates that all commits in a PR are cryptographically signed."""
1314

@@ -34,9 +35,15 @@ async def evaluate(self, context: Any) -> list[Violation]:
3435
unsigned_shas = []
3536
for commit in commits:
3637
# We will need to update the GraphQL query to fetch verificationStatus
37-
is_verified = commit.get("is_verified", False) if isinstance(commit, dict) else getattr(commit, "is_verified", False)
38+
is_verified = (
39+
commit.get("is_verified", False) if isinstance(commit, dict) else getattr(commit, "is_verified", False)
40+
)
3841
if not is_verified:
39-
sha = str(commit.get("oid", "unknown")) if isinstance(commit, dict) else str(getattr(commit, "oid", "unknown"))
42+
sha = (
43+
str(commit.get("oid", "unknown"))
44+
if isinstance(commit, dict)
45+
else str(getattr(commit, "oid", "unknown"))
46+
)
4047
unsigned_shas.append(sha[:7])
4148

4249
if unsigned_shas:
@@ -51,10 +58,6 @@ async def evaluate(self, context: Any) -> list[Violation]:
5158

5259
return []
5360

54-
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
55-
violations = await self.evaluate({"parameters": parameters, "event": event})
56-
return len(violations) == 0
57-
5861

5962
class ChangelogRequiredCondition(BaseCondition):
6063
"""Validates that a CHANGELOG update is included if source files are modified."""
@@ -83,7 +86,7 @@ async def evaluate(self, context: Any) -> list[Violation]:
8386
filename = f.get("filename", "")
8487
if not filename:
8588
continue
86-
89+
8790
# Check if it's a changelog file
8891
if "CHANGELOG" in filename.upper() or filename.startswith(".changeset/"):
8992
changelog_changed = True
@@ -101,7 +104,3 @@ async def evaluate(self, context: Any) -> list[Violation]:
101104
]
102105

103106
return []
104-
105-
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
106-
violations = await self.evaluate({"parameters": parameters, "event": event})
107-
return len(violations) == 0

src/rules/conditions/filesystem.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,14 @@ async def evaluate(self, context: Any) -> list[Violation]:
310310
compiled_pattern = re.compile(test_pattern)
311311
except re.error:
312312
logger.error(f"Invalid test_file_pattern regex: {test_pattern}")
313-
return []
313+
return [
314+
Violation(
315+
rule_description=self.description,
316+
severity=Severity.MEDIUM,
317+
message=f"Invalid test_file_pattern regex: '{test_pattern}'",
318+
how_to_fix="Fix the regular expression in the 'test_file_pattern' parameter.",
319+
)
320+
]
314321

315322
for f in changed_files:
316323
filename = f.get("filename", "")
@@ -355,14 +362,14 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b
355362
try:
356363
compiled_pattern = re.compile(test_pattern)
357364
except re.error:
358-
return True
365+
return False
359366

360367
source_modified = False
361368
test_modified = False
362369

363370
for f in changed_files:
364371
filename = f.get("filename", "")
365-
if not filename or filename.endswith(".md") or filename.endswith(".yaml"):
372+
if not filename or filename.endswith((".md", ".txt", ".yaml", ".json")):
366373
continue
367374

368375
if compiled_pattern.search(filename):

src/rules/conditions/pull_request.py

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -367,22 +367,30 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b
367367
return bool(_ISSUE_REF_PATTERN.search(combined))
368368

369369

370-
class DiffPatternCondition(BaseCondition):
371-
"""Validates that a PR diff does not contain specified restricted patterns."""
370+
class _PatchPatternCondition(BaseCondition):
371+
"""Base class for conditions that match regex patterns against PR diff patches.
372372
373-
name = "diff_pattern"
374-
description = "Checks if code changes contain restricted patterns or fail to contain required patterns."
375-
parameter_patterns = ["diff_restricted_patterns"]
376-
event_types = ["pull_request"]
377-
examples = [{"diff_restricted_patterns": ["console\\.log", "TODO:"]}]
373+
Subclasses configure the parameter key, violation severity, and message format.
374+
"""
375+
376+
_pattern_param_key: str = ""
377+
_violation_severity: Severity = Severity.MEDIUM
378+
379+
def _make_message(self, matched: list[str], filename: str) -> str:
380+
"""Return the violation message. Override for custom wording."""
381+
return f"Patterns {matched} found in added lines of {filename}"
382+
383+
def _make_how_to_fix(self) -> str:
384+
"""Return the how_to_fix text. Override for custom wording."""
385+
return "Remove the matched patterns from your code changes."
378386

379387
async def evaluate(self, context: Any) -> list[Violation]:
380-
"""Evaluate diff-pattern condition."""
388+
"""Evaluate patch-pattern condition."""
381389
parameters = context.get("parameters", {})
382390
event = context.get("event", {})
383391

384-
restricted_patterns = parameters.get("diff_restricted_patterns")
385-
if not restricted_patterns or not isinstance(restricted_patterns, list):
392+
patterns = parameters.get(self._pattern_param_key)
393+
if not patterns or not isinstance(patterns, list):
386394
return []
387395

388396
changed_files = event.get("changed_files", [])
@@ -397,98 +405,73 @@ async def evaluate(self, context: Any) -> list[Violation]:
397405
if not patch:
398406
continue
399407

400-
matched = match_patterns_in_patch(patch, restricted_patterns)
408+
matched = match_patterns_in_patch(patch, patterns)
401409
if matched:
402410
filename = file_info.get("filename", "unknown")
403411
violations.append(
404412
Violation(
405413
rule_description=self.description,
406-
severity=Severity.MEDIUM,
407-
message=f"Restricted patterns {matched} found in added lines of {filename}",
408-
how_to_fix="Remove the restricted patterns from your code changes.",
414+
severity=self._violation_severity,
415+
message=self._make_message(matched, filename),
416+
how_to_fix=self._make_how_to_fix(),
409417
)
410418
)
411419

412420
return violations
413421

414422
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
415423
"""Legacy validation interface."""
416-
restricted_patterns = parameters.get("diff_restricted_patterns")
417-
if not restricted_patterns or not isinstance(restricted_patterns, list):
424+
patterns = parameters.get(self._pattern_param_key)
425+
if not patterns or not isinstance(patterns, list):
418426
return True
419427

420428
changed_files = event.get("changed_files", [])
421429
from src.rules.utils.diff import match_patterns_in_patch
422430

423431
for file_info in changed_files:
424432
patch = file_info.get("patch")
425-
if patch and match_patterns_in_patch(patch, restricted_patterns):
433+
if patch and match_patterns_in_patch(patch, patterns):
426434
return False
427435

428436
return True
429437

430438

431-
class SecurityPatternCondition(BaseCondition):
432-
"""Detects security-sensitive patterns (like API keys) in code changes."""
439+
class DiffPatternCondition(_PatchPatternCondition):
440+
"""Validates that a PR diff does not contain specified restricted patterns."""
433441

434-
name = "security_pattern"
435-
description = "Detects hardcoded secrets, API keys, or sensitive data in PR diffs."
436-
parameter_patterns = ["security_patterns"]
442+
name = "diff_pattern"
443+
description = "Checks if code changes contain restricted patterns or fail to contain required patterns."
444+
parameter_patterns = ["diff_restricted_patterns"]
437445
event_types = ["pull_request"]
438-
examples = [{"security_patterns": ["api_key", "secret", "password", "token"]}]
439-
440-
async def evaluate(self, context: Any) -> list[Violation]:
441-
"""Evaluate security-pattern condition."""
442-
parameters = context.get("parameters", {})
443-
event = context.get("event", {})
444-
445-
security_patterns = parameters.get("security_patterns")
446-
if not security_patterns or not isinstance(security_patterns, list):
447-
return []
446+
examples = [{"diff_restricted_patterns": ["console\\.log", "TODO:"]}]
448447

449-
changed_files = event.get("changed_files", [])
450-
if not changed_files:
451-
return []
448+
_pattern_param_key = "diff_restricted_patterns"
449+
_violation_severity = Severity.MEDIUM
452450

453-
from src.rules.utils.diff import match_patterns_in_patch
451+
def _make_message(self, matched: list[str], filename: str) -> str:
452+
return f"Restricted patterns {matched} found in added lines of {filename}"
454453

455-
violations = []
456-
for file_info in changed_files:
457-
patch = file_info.get("patch")
458-
if not patch:
459-
continue
454+
def _make_how_to_fix(self) -> str:
455+
return "Remove the restricted patterns from your code changes."
460456

461-
# In a real scenario, this would use a more robust secrets scanner.
462-
# Here we just use the diff matcher with the provided regex/string patterns.
463-
matched = match_patterns_in_patch(patch, security_patterns)
464-
if matched:
465-
filename = file_info.get("filename", "unknown")
466-
violations.append(
467-
Violation(
468-
rule_description=self.description,
469-
severity=Severity.CRITICAL,
470-
message=f"Security-sensitive patterns {matched} detected in {filename}",
471-
how_to_fix="Remove hardcoded secrets or sensitive patterns from the code.",
472-
)
473-
)
474457

475-
return violations
458+
class SecurityPatternCondition(_PatchPatternCondition):
459+
"""Detects security-sensitive patterns (like API keys) in code changes."""
476460

477-
async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool:
478-
"""Legacy validation interface."""
479-
security_patterns = parameters.get("security_patterns")
480-
if not security_patterns or not isinstance(security_patterns, list):
481-
return True
461+
name = "security_pattern"
462+
description = "Detects hardcoded secrets, API keys, or sensitive data in PR diffs."
463+
parameter_patterns = ["security_patterns"]
464+
event_types = ["pull_request"]
465+
examples = [{"security_patterns": ["api_key", "secret", "password", "token"]}]
482466

483-
changed_files = event.get("changed_files", [])
484-
from src.rules.utils.diff import match_patterns_in_patch
467+
_pattern_param_key = "security_patterns"
468+
_violation_severity = Severity.CRITICAL
485469

486-
for file_info in changed_files:
487-
patch = file_info.get("patch")
488-
if patch and match_patterns_in_patch(patch, security_patterns):
489-
return False
470+
def _make_message(self, matched: list[str], filename: str) -> str:
471+
return f"Security-sensitive patterns {matched} detected in {filename}"
490472

491-
return True
473+
def _make_how_to_fix(self) -> str:
474+
return "Remove hardcoded secrets or sensitive patterns from the code."
492475

493476

494477
class UnresolvedCommentsCondition(BaseCondition):

src/rules/conditions/temporal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ async def evaluate(self, context: Any) -> list[Violation]:
261261
event = context.get("event", {})
262262

263263
max_hours = parameters.get("max_comment_response_time_hours")
264-
if not max_hours:
264+
if max_hours is None:
265265
return []
266266

267267
review_threads = event.get("review_threads", [])

0 commit comments

Comments
 (0)