Skip to content

Commit af8fef9

Browse files
cohort grouping & magic cohorts - extensive coverage (#145)
* single description for each cohort group - test, code refactored
* cohort group description based on priority
* refer status of magic cohorts based on other cohorts status
* exclude cohort when description is "" or None
* corrected magic cohorts, made cohort code not null
* integration tests for no rule-text, no cohort description
* cohort deduplication
* actionable rule reasons collection test, code cleanup
* code cleanup
* code cleanup
* lint issues
1 parent b5374f9 commit af8fef9

18 files changed

Lines changed: 1418 additions & 294 deletions

File tree

src/eligibility_signposting_api/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mangum.types import LambdaContext, LambdaEvent
99

1010
from eligibility_signposting_api import repos, services
11-
from eligibility_signposting_api.config import config, init_logging
11+
from eligibility_signposting_api.config.config import config, init_logging
1212
from eligibility_signposting_api.error_handler import handle_exception
1313
from eligibility_signposting_api.views import eligibility_blueprint
1414

src/eligibility_signposting_api/config/__init__.py

Whitespace-only changes.
File renamed without changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
MAGIC_COHORT_LABEL = "elid_all_people"

src/eligibility_signposting_api/model/eligibility.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ConditionName = NewType("ConditionName", str)
1313

1414
RuleName = NewType("RuleName", str)
15-
RuleResult = NewType("RuleResult", str)
15+
RuleDescription = NewType("RuleDescription", str)
1616

1717

1818
class RuleType(StrEnum):
@@ -60,28 +60,28 @@ def best(*statuses: Status) -> Status:
6060
class Reason:
6161
rule_type: RuleType
6262
rule_name: RuleName
63-
rule_result: RuleResult
63+
rule_description: RuleDescription | None
6464

6565

6666
@dataclass
6767
class Condition:
6868
condition_name: ConditionName
6969
status: Status
70-
cohort_results: list[CohortResult]
70+
cohort_results: list[CohortGroupResult]
7171

7272

7373
@dataclass
74-
class CohortResult:
74+
class CohortGroupResult:
7575
cohort_code: str
7676
status: Status
7777
reasons: list[Reason]
78-
description: str
78+
description: str | None
7979

8080

8181
@dataclass
8282
class IterationResult:
8383
status: Status
84-
cohort_results: list[CohortResult]
84+
cohort_results: list[CohortGroupResult]
8585

8686

8787
@dataclass

src/eligibility_signposting_api/model/rules.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator
1212

13+
from eligibility_signposting_api.config.contants import MAGIC_COHORT_LABEL
14+
1315
if typing.TYPE_CHECKING: # pragma: no cover
1416
from pydantic import SerializationInfo
1517

@@ -85,8 +87,8 @@ class RuleAttributeLevel(StrEnum):
8587

8688

8789
class IterationCohort(BaseModel):
88-
cohort_label: CohortLabel | None = Field(None, alias="CohortLabel")
89-
cohort_group: CohortGroup | None = Field(None, alias="CohortGroup")
90+
cohort_label: CohortLabel = Field(alias="CohortLabel")
91+
cohort_group: CohortGroup = Field(alias="CohortGroup")
9092
positive_description: Description | None = Field(None, alias="PositiveDescription")
9193
negative_description: Description | None = Field(None, alias="NegativeDescription")
9294
priority: int | None = Field(None, alias="Priority")
@@ -107,14 +109,14 @@ class IterationRule(BaseModel):
107109
attribute_target: RuleAttributeTarget | None = Field(None, alias="AttributeTarget")
108110
rule_stop: RuleStop = Field(RuleStop(False), alias="RuleStop") # noqa: FBT003
109111

112+
model_config = {"populate_by_name": True, "extra": "ignore"}
113+
110114
@field_validator("rule_stop", mode="before")
111115
def parse_yn_to_bool(cls, v: str | bool) -> bool: # noqa: N805
112116
if isinstance(v, str):
113117
return v.upper() == "Y"
114118
return v
115119

116-
model_config = {"populate_by_name": True, "extra": "ignore"}
117-
118120

119121
class Iteration(BaseModel):
120122
id: IterationID = Field(..., alias="ID")
@@ -142,6 +144,17 @@ def parse_dates(cls, v: str | date) -> date:
142144
def serialize_dates(v: date, _info: SerializationInfo) -> str:
143145
return v.strftime("%Y%m%d")
144146

147+
@cached_property
148+
def has_magic_cohort(self) -> bool:
149+
return next(
150+
(
151+
True
152+
for cc in self.iteration_cohorts
153+
if cc.cohort_label and cc.cohort_label.upper() == MAGIC_COHORT_LABEL.upper()
154+
),
155+
False,
156+
)
157+
145158

146159
class CampaignConfig(BaseModel):
147160
id: CampaignID = Field(..., alias="ID")

src/eligibility_signposting_api/repos/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from wireup import Inject, service
88
from yarl import URL
99

10-
from eligibility_signposting_api.config import AwsAccessKey, AwsRegion, AwsSecretAccessKey
10+
from eligibility_signposting_api.config.config import AwsAccessKey, AwsRegion, AwsSecretAccessKey
1111

1212
logger = logging.getLogger(__name__)
1313

src/eligibility_signposting_api/services/calculators/eligibility_calculator.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from _operator import attrgetter
4+
from collections import defaultdict
45
from collections.abc import Collection, Iterable, Iterator, Mapping
56
from dataclasses import dataclass, field
67
from itertools import groupby
@@ -13,16 +14,17 @@
1314

1415
from eligibility_signposting_api.model import eligibility, rules
1516
from eligibility_signposting_api.model.eligibility import (
16-
CohortResult,
17+
CohortGroupResult,
1718
Condition,
1819
ConditionName,
1920
IterationResult,
2021
Status,
2122
)
22-
from eligibility_signposting_api.services.calculators.rule_calculator import RuleCalculator
23+
from eligibility_signposting_api.services.calculators.rule_calculator import (
24+
RuleCalculator,
25+
)
2326

2427
Row = Collection[Mapping[str, Any]]
25-
magic_cohort = "elid_all_people"
2628

2729

2830
@service
@@ -49,23 +51,39 @@ def campaigns_grouped_by_condition_name(
4951
) -> Iterator[tuple[eligibility.ConditionName, list[rules.CampaignConfig]]]:
5052
"""Generator function to iterate over campaign groups by condition name."""
5153
for condition_name, campaign_group in groupby(
52-
sorted(self.active_campaigns, key=attrgetter("target")), key=attrgetter("target")
54+
sorted(self.active_campaigns, key=attrgetter("target")),
55+
key=attrgetter("target"),
5356
):
5457
yield condition_name, list(campaign_group)
5558

5659
@property
5760
def person_cohorts(self) -> set[str]:
5861
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
59-
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
62+
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
63+
{},
6064
)
6165
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
6266

6367
@staticmethod
64-
def get_best_cohort(cohort_results: dict[str, CohortResult]) -> tuple[Status, list[CohortResult]]:
68+
def get_the_best_cohort_memberships(
69+
cohort_results: dict[str, CohortGroupResult],
70+
) -> tuple[Status, list[CohortGroupResult]]:
6571
if not cohort_results:
6672
return eligibility.Status.not_eligible, []
73+
6774
best_status = eligibility.Status.best(*[result.status for result in cohort_results.values()])
6875
best_cohorts = [result for result in cohort_results.values() if result.status == best_status]
76+
77+
best_cohorts = [
78+
CohortGroupResult(
79+
cohort_code=cc.cohort_code,
80+
status=cc.status,
81+
reasons=cc.reasons,
82+
description=(cc.description or "").strip() if cc.description else "",
83+
)
84+
for cc in best_cohorts
85+
]
86+
6987
return best_status, best_cohorts
7088

7189
@staticmethod
@@ -98,28 +116,30 @@ def evaluate_eligibility(self) -> eligibility.EligibilityStatus:
98116
iteration_results: dict[str, IterationResult] = {}
99117

100118
for active_iteration in [cc.current_iteration for cc in campaign_group]:
101-
cohort_results: dict[str, CohortResult] = {}
119+
cohort_results: dict[str, CohortGroupResult] = {}
102120

103121
filter_rules, suppression_rules = self.get_rules_by_type(active_iteration)
122+
104123
for cohort in sorted(active_iteration.iteration_cohorts, key=attrgetter("priority")):
105124
# Base Eligibility - check
106-
if cohort.cohort_label in self.person_cohorts or cohort.cohort_label == magic_cohort:
125+
if cohort.cohort_label in self.person_cohorts or active_iteration.has_magic_cohort:
107126
# Eligibility - check
108127
if self.is_eligible_by_filter_rules(cohort, cohort_results, filter_rules):
109128
# Actionability - evaluation
110129
self.evaluate_suppression_rules(cohort, cohort_results, suppression_rules)
111130

112131
# Not base eligible
113132
elif cohort.cohort_label is not None:
114-
cohort_results[cohort.cohort_label] = CohortResult(
115-
cohort.cohort_group if cohort.cohort_group else cohort.cohort_label,
133+
cohort_results[cohort.cohort_label] = CohortGroupResult(
134+
(cohort.cohort_group),
116135
Status.not_eligible,
117136
[],
118-
str(cohort.negative_description),
137+
cohort.negative_description,
119138
)
120139

121140
# Determine Result between cohorts - get the best
122-
status, best_cohorts = self.get_best_cohort(cohort_results)
141+
status, best_cohorts = self.get_the_best_cohort_memberships(cohort_results)
142+
123143
iteration_results[active_iteration.name] = IterationResult(status, best_cohorts)
124144

125145
# Determine results between iterations - get the best
@@ -130,20 +150,50 @@ def evaluate_eligibility(self) -> eligibility.EligibilityStatus:
130150
condition_results[condition_name] = best_candidate
131151

132152
# Consolidate all the results and return
133-
final_result = [
134-
Condition(
135-
condition_name=condition_name,
136-
status=active_iteration_result.status,
137-
cohort_results=active_iteration_result.cohort_results,
138-
)
139-
for condition_name, active_iteration_result in condition_results.items()
140-
]
153+
final_result = self.build_condition_results(condition_results)
141154
return eligibility.EligibilityStatus(conditions=final_result)
142155

156+
@staticmethod
157+
def build_condition_results(
158+
condition_results: dict[ConditionName, IterationResult],
159+
) -> list[Condition]:
160+
conditions: list[Condition] = []
161+
# iterate over conditions
162+
for condition_name, active_iteration_result in condition_results.items():
163+
grouped_cohort_results = defaultdict(list)
164+
# iterate over cohorts and group them by status and cohort_group
165+
for cohort_result in active_iteration_result.cohort_results:
166+
if active_iteration_result.status == cohort_result.status:
167+
grouped_cohort_results[cohort_result.cohort_code].append(cohort_result)
168+
169+
# deduplicate grouped cohort results by cohort_code
170+
deduplicated_cohort_results = [
171+
CohortGroupResult(
172+
cohort_code=group_cohort_code,
173+
status=group[0].status,
174+
# Flatten all reasons from the group
175+
reasons=[reason for cohort in group for reason in cohort.reasons],
176+
# get the first nonempty description
177+
description=next((c.description for c in group if c.description), group[0].description),
178+
)
179+
for group_cohort_code, group in grouped_cohort_results.items()
180+
if group
181+
]
182+
183+
# return condition with cohort results
184+
conditions.append(
185+
Condition(
186+
condition_name=condition_name,
187+
status=active_iteration_result.status,
188+
cohort_results=list(deduplicated_cohort_results),
189+
)
190+
)
191+
return conditions
192+
143193
def is_eligible_by_filter_rules(
144194
self,
145195
cohort: IterationCohort,
146-
cohort_results: dict[str, CohortResult],
196+
cohort_results: dict[str, CohortGroupResult],
147197
filter_rules: Iterable[rules.IterationRule],
148198
) -> bool:
149199
is_eligible = True
@@ -156,11 +206,11 @@ def is_eligible_by_filter_rules(
156206
)
157207
if status.is_exclusion:
158208
if cohort.cohort_label is not None:
159-
cohort_results[str(cohort.cohort_label)] = CohortResult(
160-
cohort.cohort_group if cohort.cohort_group else cohort.cohort_label,
209+
cohort_results[cohort.cohort_label] = CohortGroupResult(
210+
(cohort.cohort_group),
161211
Status.not_eligible,
162212
[],
163-
str(cohort.negative_description),
213+
cohort.negative_description,
164214
)
165215
is_eligible = False
166216
break
@@ -169,7 +219,7 @@ def is_eligible_by_filter_rules(
169219
def evaluate_suppression_rules(
170220
self,
171221
cohort: IterationCohort,
172-
cohort_results: dict[str, CohortResult],
222+
cohort_results: dict[str, CohortGroupResult],
173223
suppression_rules: Iterable[rules.IterationRule],
174224
) -> None:
175225
is_actionable: bool = True
@@ -191,18 +241,18 @@ def evaluate_suppression_rules(
191241
if cohort.cohort_label is not None:
192242
key = cohort.cohort_label
193243
if is_actionable:
194-
cohort_results[key] = CohortResult(
195-
cohort.cohort_group if cohort.cohort_group else key,
244+
cohort_results[key] = CohortGroupResult(
245+
cohort.cohort_group,
196246
Status.actionable,
197247
[],
198-
str(cohort.positive_description),
248+
cohort.positive_description,
199249
)
200250
else:
201-
cohort_results[key] = CohortResult(
202-
cohort.cohort_group if cohort.cohort_group else key,
251+
cohort_results[key] = CohortGroupResult(
252+
cohort.cohort_group,
203253
Status.not_actionable,
204254
suppression_reasons,
205-
str(cohort.positive_description),
255+
cohort.positive_description,
206256
)
207257

208258
def evaluate_rules_priority_group(

src/eligibility_signposting_api/services/calculators/rule_calculator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def evaluate_exclusion(self) -> tuple[eligibility.Status, eligibility.Reason]:
2424
reason = eligibility.Reason(
2525
rule_name=eligibility.RuleName(self.rule.name),
2626
rule_type=eligibility.RuleType(self.rule.type),
27-
rule_result=eligibility.RuleResult(self.rule.description),
27+
rule_description=eligibility.RuleDescription(self.rule.description),
2828
)
2929
return status, reason
3030

src/eligibility_signposting_api/views/eligibility.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import uuid
3-
from collections import defaultdict
43
from datetime import UTC, datetime
54
from http import HTTPStatus
65

@@ -72,24 +71,19 @@ def build_eligibility_response(eligibility_status: EligibilityStatus) -> eligibi
7271
def build_eligibility_cohorts(condition: Condition) -> list[eligibility.EligibilityCohort]:
7372
"""Group Iteration cohorts and make only one entry per cohort group"""
7473

75-
grouped_cohort_results = defaultdict(list)
76-
77-
for cohort_result in condition.cohort_results:
78-
if condition.status == cohort_result.status:
79-
grouped_cohort_results[cohort_result.cohort_code].append(cohort_result)
80-
8174
return [
8275
eligibility.EligibilityCohort(
83-
cohortCode=cohort_group_code,
84-
cohortText=cohort_group[0].description,
85-
cohortStatus=STATUS_MAPPING[cohort_group[0].status],
76+
cohortCode=eligibility.CohortCode(cohort_result.cohort_code),
77+
cohortText=eligibility.CohortText(cohort_result.description),
78+
cohortStatus=STATUS_MAPPING[cohort_result.status],
8679
)
87-
for cohort_group_code, cohort_group in grouped_cohort_results.items()
88-
if cohort_group
80+
for cohort_result in condition.cohort_results
81+
if cohort_result and condition.status == cohort_result.status and cohort_result.description
8982
]
9083

9184

9285
def build_suitability_results(condition: Condition) -> list[eligibility.SuitabilityRule]:
86+
"""Make only one entry if there are duplicate rules"""
9387
if condition.status != Status.not_actionable:
9488
return []
9589

@@ -99,13 +93,13 @@ def build_suitability_results(condition: Condition) -> list[eligibility.Suitabil
9993
for cohort_result in condition.cohort_results:
10094
if cohort_result.status == Status.not_actionable:
10195
for reason in cohort_result.reasons:
102-
if reason.rule_name not in unique_rule_codes:
96+
if reason.rule_name not in unique_rule_codes and reason.rule_description:
10397
unique_rule_codes.add(reason.rule_name)
10498
suitability_results.append(
10599
eligibility.SuitabilityRule(
106100
ruleType=eligibility.RuleType(reason.rule_type.value),
107101
ruleCode=eligibility.RuleCode(reason.rule_name),
108-
ruleText=eligibility.RuleText(reason.rule_result),
102+
ruleText=eligibility.RuleText(reason.rule_description),
109103
)
110104
)
111105

0 commit comments

Comments
 (0)