Skip to content

Commit 31ec71c

Browse files
Merge pull request #614 from NHSDigital/eli-595
Eli 595
2 parents cfc1827 + 3e91c9f commit 31ec71c

4 files changed

Lines changed: 62 additions & 13 deletions

File tree

src/rules_validation_api/validators/iteration_validator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,46 @@ def transform_actions_mapper(cls, action_mapper: ActionsMapper) -> ActionsMapper
9797
action_mapper.root = new_root
9898
return action_mapper
9999

100+
@model_validator(mode="after")
101+
def validate_rule_cohort_labels_against_iteration_cohorts(self) -> typing.Self:
102+
allowed_labels = {c.cohort_label for c in self.iteration_cohorts}
103+
line_errors: list[InitErrorDetails] = []
104+
105+
# Pre compute allowed label string once
106+
allowed_str = ", ".join(sorted(allowed_labels)) if allowed_labels else None
107+
108+
for idx, rule in enumerate(self.iteration_rules):
109+
if not rule.cohort_label:
110+
continue
111+
112+
for label in rule.parsed_cohort_labels:
113+
if label in allowed_labels:
114+
continue
115+
116+
# Build error message
117+
error_message = (
118+
f"Invalid cohort_label value '{label}'. Allowed values: {allowed_str}."
119+
if allowed_str
120+
else (
121+
f"Invalid cohort_label value '{label}'. "
122+
"No iteration cohorts are defined, so no labels are allowed."
123+
)
124+
)
125+
126+
line_errors.append(
127+
InitErrorDetails(
128+
type="value_error",
129+
loc=("iteration_rules", idx, "cohort_label"),
130+
input=rule.cohort_label,
131+
ctx={"error": error_message},
132+
)
133+
)
134+
135+
if line_errors:
136+
raise ValidationError.from_exception_data(title="IterationValidation", line_errors=line_errors)
137+
138+
return self
139+
100140
@model_validator(mode="after")
101141
def action_mapper_validation(self) -> typing.Self:
102142
all_errors = []

tests/unit/validation/test_campaign_config_validator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@ class TestMandatoryFieldsSchemaValidations:
1313
def test_campaign_config_with_only_mandatory_fields_configuration(
1414
self, valid_campaign_config_with_only_mandatory_fields
1515
):
16-
try:
17-
CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields)
18-
except ValidationError as e:
19-
pytest.fail(f"Unexpected error during model instantiation: {e}")
16+
CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields)
2017

2118
@pytest.mark.parametrize(
2219
"mandatory_field",

tests/unit/validation/test_iteration_rules_validator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ class TestMandatoryFieldsSchemaValidations:
99
def test_campaign_config_with_only_mandatory_fields_configuration(
1010
self, valid_iteration_rule_with_only_mandatory_fields
1111
):
12-
try:
13-
IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields)
14-
except ValidationError as e:
15-
pytest.fail(f"Unexpected error during model instantiation: {e}")
12+
IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields)
1613

1714
@pytest.mark.parametrize(
1815
"mandatory_field",

tests/unit/validation/test_iteration_validator.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ class TestMandatoryFieldsSchemaValidations:
1414
def test_campaign_config_with_only_mandatory_fields_configuration(
1515
self, valid_campaign_config_with_only_mandatory_fields
1616
):
17-
try:
18-
IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0]))
19-
except ValidationError as e:
20-
pytest.fail(f"Unexpected error during model instantiation: {e}")
17+
IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0]))
2118

2219
@pytest.mark.parametrize(
2320
"mandatory_field",
@@ -556,7 +553,7 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913
556553
data = valid_campaign_config_with_only_mandatory_fields.copy()
557554

558555
if default_time_iteration_input:
559-
data["iteration_time"] = default_time_iteration_input
556+
data["IterationTime"] = default_time_iteration_input
560557

561558
data["Iterations"] = [iteration_data]
562559

@@ -570,3 +567,21 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913
570567
f"Failed! Input: {iteration_time_input}, Default: {default_time_iteration_input}. "
571568
f"Expected {expected_date_time} but got {result}"
572569
)
570+
571+
def test_iteration_rules_having_invalid_cohort_labels_throws_error(
572+
self,
573+
valid_iteration_with_only_mandatory_fields,
574+
valid_iteration_rule_with_only_mandatory_fields,
575+
valid_iteration_cohorts,
576+
):
577+
data = valid_iteration_with_only_mandatory_fields.copy()
578+
data["IterationRules"] = [valid_iteration_rule_with_only_mandatory_fields]
579+
data["IterationCohorts"] = [valid_iteration_cohorts()]
580+
data["IterationRules"][0]["CohortLabel"] = "label_2"
581+
582+
with pytest.raises(ValidationError) as exc_info:
583+
IterationValidation(**data)
584+
585+
errors = exc_info.value.errors()
586+
# Ensure at least one error is specifically about the invalid CohortLabel in IterationRules[0]
587+
assert any(err.get("loc", [])[:3] == ("iteration_rules", 0, "cohort_label") for err in errors)

0 commit comments

Comments
 (0)