Skip to content

Commit de8ebac

Browse files
authored
ELI-351 and ELI-342: Refactors and fixes Cohort Schema Mismatch (#253)
* Adds campaign_evaluator and tests
* Adds person_data_reader and tests
* Injects person data reader and campaign processor into eligibility calculator
* ELI-342 Dynamo Cohort Schema Mismatch
* ELI-342: Fixes usage of person cohorts method and tests
1 parent ea45743 commit de8ebac

13 files changed

Lines changed: 298 additions & 102 deletions

File tree

src/eligibility_signposting_api/services/calculators/eligibility_calculator.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import TYPE_CHECKING, Any
99

1010
from eligibility_signposting_api.audit.audit_context import AuditContext
11+
from eligibility_signposting_api.services.processors.campaign_evaluator import CampaignEvaluator
12+
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader
1113

1214
if TYPE_CHECKING:
1315
from eligibility_signposting_api.model.rules import (
@@ -58,42 +60,10 @@ class EligibilityCalculator:
5860
person_data: Row
5961
campaign_configs: Collection[rules.CampaignConfig]
6062

61-
results: list[eligibility_status.Condition] = field(default_factory=list)
62-
63-
@property
64-
def active_campaigns(self) -> list[rules.CampaignConfig]:
65-
return [cc for cc in self.campaign_configs if cc.campaign_live]
66-
67-
def campaigns_grouped_by_condition_name(
68-
self, conditions: list[str], category: str
69-
) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]:
70-
"""Generator that yields campaign groups filtered by condition names and campaign category."""
71-
72-
mapping = {
73-
"ALL": {"V", "S"},
74-
"VACCINATIONS": {"V"},
75-
"SCREENING": {"S"},
76-
}
77-
78-
allowed_types = mapping.get(category, set())
63+
campaign_evaluator: CampaignEvaluator = field(default_factory=CampaignEvaluator)
64+
person_data_reader: PersonDataReader = field(default_factory=PersonDataReader)
7965

80-
filter_all_conditions = "ALL" in conditions
81-
82-
for condition_name, campaign_group in groupby(
83-
sorted(self.active_campaigns, key=attrgetter("target")),
84-
key=attrgetter("target"),
85-
):
86-
campaigns = list(campaign_group)
87-
if campaigns[0].type in allowed_types and (filter_all_conditions or str(condition_name) in conditions):
88-
yield condition_name, campaigns
89-
90-
@property
91-
def person_cohorts(self) -> set[str]:
92-
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
93-
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
94-
{},
95-
)
96-
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
66+
results: list[eligibility_status.Condition] = field(default_factory=list)
9767

9868
@staticmethod
9969
def get_the_best_cohort_memberships(
@@ -164,7 +134,10 @@ def evaluate_eligibility(
164134
actions: list[SuggestedAction] | None = []
165135
action_rule_priority, action_rule_name = None, None
166136

167-
for condition_name, campaign_group in self.campaigns_grouped_by_condition_name(conditions, category):
137+
requested_grouped_campaigns = self.campaign_evaluator.get_requested_grouped_campaigns(
138+
self.campaign_configs, conditions, category
139+
)
140+
for condition_name, campaign_group in requested_grouped_campaigns:
168141
best_active_iteration: Iteration | None
169142
best_candidate: IterationResult
170143
best_campaign_id: CampaignID | None
@@ -289,7 +262,8 @@ def get_cohort_results(self, active_iteration: rules.Iteration) -> dict[str, Coh
289262
filter_rules, suppression_rules = self.get_rules_by_type(active_iteration)
290263
for cohort in sorted(active_iteration.iteration_cohorts, key=attrgetter("priority")):
291264
# Base Eligibility - check
292-
if cohort.cohort_label in self.person_cohorts or cohort.is_magic_cohort:
265+
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
266+
if cohort.cohort_label in person_cohorts or cohort.is_magic_cohort:
293267
# Eligibility - check
294268
if self.is_eligible_by_filter_rules(cohort, cohort_results, filter_rules):
295269
# Actionability - evaluation

src/eligibility_signposting_api/services/calculators/rule_calculator.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

33
from collections.abc import Collection, Mapping
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Any
66

77
from hamcrest.core.string_description import StringDescription
88

99
from eligibility_signposting_api.model import eligibility_status, rules
10-
from eligibility_signposting_api.services.rules.operators import OperatorRegistry
10+
from eligibility_signposting_api.services.operators.operators import OperatorRegistry
11+
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader
1112

1213
Row = Collection[Mapping[str, Any]]
1314

@@ -17,6 +18,8 @@ class RuleCalculator:
1718
person_data: Row
1819
rule: rules.IterationRule
1920

21+
person_data_reader: PersonDataReader = field(default_factory=PersonDataReader)
22+
2023
def evaluate_exclusion(self) -> tuple[eligibility_status.Status, eligibility_status.Reason]:
2124
"""Evaluate if a particular rule excludes this person. Return the result, and the reason for the result."""
2225
attribute_value = self.get_attribute_value()
@@ -43,15 +46,7 @@ def get_attribute_value(self) -> str | None:
4346
(r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None
4447
)
4548
if cohorts:
46-
attr_name = (
47-
"COHORT_MAP"
48-
if not self.rule.attribute_name or self.rule.attribute_name == "COHORT_LABEL"
49-
else self.rule.attribute_name
50-
)
51-
cohort_map = self.get_value(cohorts, attr_name)
52-
cohorts_dict = self.get_value(cohort_map, "cohorts")
53-
m_dict = self.get_value(cohorts_dict, "M")
54-
person_cohorts: set[str] = set(m_dict.keys())
49+
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
5550
attribute_value = ",".join(person_cohorts)
5651
else:
5752
attribute_value = None
@@ -66,11 +61,6 @@ def get_attribute_value(self) -> str | None:
6661
raise NotImplementedError(msg)
6762
return attribute_value
6863

69-
@staticmethod
70-
def get_value(dictionary: Mapping[str, Any] | None, key: str) -> dict:
71-
v = dictionary.get(key, {}) if isinstance(dictionary, dict) else {}
72-
return v if isinstance(v, dict) else {}
73-
7464
def evaluate_rule(self, attribute_value: str | None) -> tuple[eligibility_status.Status, str, bool]:
7565
"""Evaluate a rule against a person data attribute. Return the result, and the reason for the result."""
7666
matcher_class = OperatorRegistry.get(self.rule.operator)

src/eligibility_signposting_api/services/rules/__init__.py renamed to src/eligibility_signposting_api/services/operators/__init__.py

File renamed without changes.

src/eligibility_signposting_api/services/rules/operators.py renamed to src/eligibility_signposting_api/services/operators/operators.py

File renamed without changes.

src/eligibility_signposting_api/services/processors/__init__.py

Whitespace-only changes.

src/eligibility_signposting_api/services/processors/campaign_evaluator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from collections.abc import Collection, Iterator
2+
from itertools import groupby
3+
from operator import attrgetter
4+
5+
from wireup import service
6+
7+
from eligibility_signposting_api.model import eligibility_status, rules
8+
9+
10+
@service
class CampaignEvaluator:
    """Picks out live campaign configurations and groups them by target condition."""

    def get_active_campaigns(self, campaign_configs: Collection[rules.CampaignConfig]) -> list[rules.CampaignConfig]:
        """Return the configs whose ``campaign_live`` attribute is truthy, keeping input order."""
        return list(filter(attrgetter("campaign_live"), campaign_configs))

    def get_requested_grouped_campaigns(
        self, campaign_configs: Collection[rules.CampaignConfig], conditions: list[str], category: str
    ) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]:
        """Yield ``(target, campaigns)`` pairs for the live campaigns matching the request.

        Live campaigns are sorted and grouped by their ``target``. A group is
        yielded when both of the following hold:

        * the ``type`` of the first campaign in the group is allowed for
          ``category`` ("ALL" allows "V" and "S", "VACCINATIONS" allows "V",
          "SCREENING" allows "S"; any other category allows nothing);
        * ``conditions`` contains "ALL", or contains the group's target
          rendered with ``str()``.

        Args:
            campaign_configs: Candidate campaign configurations.
            conditions: Requested condition names, or a list containing "ALL".
            category: Requested campaign category.

        Yields:
            The target and the list of live campaigns sharing that target.
        """
        types_by_category = {
            "ALL": {"V", "S"},
            "VACCINATIONS": {"V"},
            "SCREENING": {"S"},
        }
        permitted_types = types_by_category.get(category, set())
        any_condition = "ALL" in conditions

        # groupby only merges adjacent items, so sort by the same key first.
        by_target = attrgetter("target")
        live_sorted = sorted(self.get_active_campaigns(campaign_configs), key=by_target)

        for target, members in groupby(live_sorted, key=by_target):
            group = list(members)
            if not group:
                continue
            # Only the first campaign of the group decides the type match.
            if group[0].type not in permitted_types:
                continue
            if any_condition or str(target) in conditions:
                yield target, group

src/eligibility_signposting_api/services/processors/person_data_reader.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Collection, Mapping
4+
from typing import Any
5+
6+
from wireup import service
7+
8+
Row = Collection[Mapping[str, Any]]
9+
10+
11+
@service
class PersonDataReader:
    """Handles extracting and interpreting person data."""

    def get_person_cohorts(self, person_data: Row) -> set[str]:
        """Return the cohort labels recorded for a person.

        The first row whose ``ATTRIBUTE_TYPE`` is ``"COHORTS"`` is used; its
        ``COHORT_MEMBERSHIPS`` entry is read as a list of mappings, each of
        which may carry a ``COHORT_LABEL``.

        Args:
            person_data: The person's rows.

        Returns:
            The set of non-empty ``COHORT_LABEL`` values. Empty when there is
            no cohorts row, when ``COHORT_MEMBERSHIPS`` is missing, null or
            empty, or when no membership has a label.
        """
        cohorts_row: Mapping[str, Any] = next(
            (row for row in person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
            {},
        )
        # `or []` treats a key that is present but null the same as a missing
        # key, instead of failing when iterating over None.
        memberships = cohorts_row.get("COHORT_MEMBERSHIPS") or []
        # Single lookup per membership; falsy labels (None, "") are skipped.
        return {label for membership in memberships if (label := membership.get("COHORT_LABEL"))}

tests/fixtures/builders/repos/person.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,9 @@ def person_rows_builder( # noqa:PLR0913
6666
{
6767
"NHS_NUMBER": key,
6868
"ATTRIBUTE_TYPE": "COHORTS",
69-
"COHORT_MAP": {
70-
"cohorts": {
71-
"M": {
72-
cohort: {"M": {"dateJoined": {"S": faker.past_date().strftime("%Y%m%d")}}} for cohort in cohorts
73-
}
74-
}
75-
},
69+
"COHORT_MEMBERSHIPS": [
70+
{"COHORT_LABEL": cohort, "DATE_JOINED": faker.past_date().strftime("%Y%m%d")} for cohort in cohorts
71+
],
7672
},
7773
]
7874
rows.extend(

tests/unit/services/calculators/test_eligibility_calculator.py

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from faker import Faker
66
from flask import Flask, g
77
from freezegun import freeze_time
8-
from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_, is_in
8+
from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_in
99
from pydantic import HttpUrl, ValidationError
1010

1111
from eligibility_signposting_api.audit.audit_models import AuditAction, AuditEvent
@@ -95,7 +95,8 @@ def test_not_base_eligible(faker: Faker):
9595
target="RSV",
9696
iterations=[
9797
rule_builder.IterationFactory.build(
98-
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")]
98+
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")],
99+
iteration_rules=[],
99100
)
100101
],
101102
)
@@ -175,6 +176,7 @@ def test_only_live_campaigns_considered(faker: Faker):
175176
iterations=[
176177
rule_builder.IterationFactory.build(
177178
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")],
179+
iteration_rules=[],
178180
)
179181
],
180182
start_date=datetime.date(2025, 4, 20),
@@ -218,7 +220,7 @@ def test_campaigns_with_applicable_iteration_types_in_campaign_level_considered(
218220
# Given
219221
nhs_number = NHSNumber(faker.nhs_number())
220222

221-
person_rows = person_rows_builder(nhs_number)
223+
person_rows = person_rows_builder(nhs_number, cohorts=[])
222224
campaign_configs = [rule_builder.CampaignConfigFactory.build(target="RSV", iteration_type=iteration_type)]
223225

224226
calculator = EligibilityCalculator(person_rows, campaign_configs)
@@ -247,7 +249,7 @@ def test_campaigns_with_applicable_iteration_types_in_iteration_level_considered
247249
# Given
248250
nhs_number = NHSNumber(faker.nhs_number())
249251

250-
person_rows = person_rows_builder(nhs_number)
252+
person_rows = person_rows_builder(nhs_number, cohorts=[])
251253
campaign_configs = [
252254
rule_builder.CampaignConfigFactory.build(
253255
target="RSV", iterations=[rule_builder.IterationFactory.build(type=iteration_type)]
@@ -935,11 +937,6 @@ def test_status_on_target_based_on_last_successful_date(
935937
Status.not_eligible,
936938
"cohort label is the default attribute name for the cohort attribute level",
937939
),
938-
(
939-
rules.RuleAttributeName("LOCATION"),
940-
Status.actionable,
941-
"attribute name that is not cohort label",
942-
),
943940
],
944941
)
945942
def test_status_on_cohort_attribute_level(
@@ -2383,38 +2380,6 @@ def test_should_not_include_actions_when_include_actions_flag_is_false_when_stat
23832380
)
23842381

23852382

2386-
@pytest.mark.parametrize(
2387-
("campaign_target", "campaign_type", "conditions_filter", "category_filter", "expected_result"),
2388-
[
2389-
# Multiple matching campaigns under the same condition
2390-
("RSV", "V", ["RSV"], "VACCINATIONS", [("RSV", "V")]),
2391-
("RSV", "V", ["COVID"], "VACCINATIONS", []),
2392-
("RSV", "S", ["RSV"], "ALL", [("RSV", "S")]),
2393-
("RSV", "S", ["ALL"], "ALL", [("RSV", "S")]),
2394-
("RSV", "S", ["RSV"], "VACCINATIONS", []),
2395-
# Multiple campaigns with different types under the same condition name
2396-
("RSV", "V", ["RSV"], "ALL", [("RSV", "V")]),
2397-
# Campaign is live but condition not in filter (no yield)
2398-
("FLU", "V", ["COVID", "RSV"], "ALL", []),
2399-
# Category is ALL and condition filter includes ALL (everything matches)
2400-
("FLU", "S", ["ALL"], "ALL", [("FLU", "S")]),
2401-
# Condition filter is unknown (should not match anything)
2402-
("COVID", "V", ["UNKNOWN"], "VACCINATIONS", []),
2403-
# Campaign with the target matching one of several condition filters
2404-
("FLU", "V", ["COVID", "FLU"], "VACCINATIONS", [("FLU", "V")]),
2405-
],
2406-
)
2407-
def test_campaigns_grouped_by_condition_name_filters_correctly(
2408-
campaign_target, campaign_type, conditions_filter, category_filter, expected_result
2409-
):
2410-
campaign = rule_builder.CampaignConfigFactory.build(target=campaign_target, type=campaign_type, campaign_live=True)
2411-
2412-
calculator = EligibilityCalculator(person_data=[], campaign_configs=[campaign])
2413-
result = list(calculator.campaigns_grouped_by_condition_name(conditions_filter, category_filter))
2414-
2415-
assert_that([(str(name), group[0].type) for name, group in result], is_(expected_result))
2416-
2417-
24182383
@pytest.mark.parametrize(
24192384
(
24202385
"test_comment",

tests/unit/services/operators/test_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from hamcrest import assert_that, equal_to
44

55
from eligibility_signposting_api.model.rules import RuleOperator
6-
from eligibility_signposting_api.services.rules.operators import Operator, OperatorRegistry
6+
from eligibility_signposting_api.services.operators.operators import Operator, OperatorRegistry
77

88
# Test cases: person_data, rule_operator, rule_value, expected, test_comment
99
cases: list[tuple[str | None, RuleOperator, str | None, bool, str]] = []

0 commit comments

Comments
 (0)