Skip to content

Commit 237fc6d

Browse files
committed
ELI-351: Extracts Person data class
1 parent 0419783 commit 237fc6d

11 files changed

Lines changed: 124 additions & 97 deletions

File tree

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
4+
5+
@dataclass
6+
class Person:
7+
data: list[dict[str, Any]]

src/eligibility_signposting_api/repos/person_repo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from wireup import Inject, service
77

88
from eligibility_signposting_api.model.eligibility_status import NHSNumber
9+
from eligibility_signposting_api.model.person import Person
910
from eligibility_signposting_api.repos.exceptions import NotFoundError
1011

1112
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ def __init__(self, table: Annotated[Any, Inject(qualifier="person_table")]) -> N
3536
super().__init__()
3637
self.table = table
3738

38-
def get_eligibility_data(self, nhs_number: NHSNumber) -> list[dict[str, Any]]:
39+
def get_eligibility_data(self, nhs_number: NHSNumber) -> Person:
3940
response = self.table.query(KeyConditionExpression=Key("NHS_NUMBER").eq(nhs_number))
4041
logger.debug("response %r for %r", response, nhs_number, extra={"response": response, "nhs_number": nhs_number})
4142

@@ -44,4 +45,5 @@ def get_eligibility_data(self, nhs_number: NHSNumber) -> list[dict[str, Any]]:
4445
raise NotFoundError(message)
4546

4647
logger.debug("returning items %s", items, extra={"items": items})
47-
return items
48+
49+
return Person(data=items)

src/eligibility_signposting_api/services/calculators/eligibility_calculator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
from _operator import attrgetter
44
from collections import defaultdict
5-
from collections.abc import Collection, Iterable, Iterator, Mapping
65
from dataclasses import dataclass, field
76
from itertools import groupby
8-
from typing import Any
7+
from typing import TYPE_CHECKING
98

109
from wireup import service
1110

@@ -43,19 +42,22 @@
4342
from eligibility_signposting_api.services.processors.campaign_evaluator import CampaignEvaluator
4443
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader
4544

46-
Row = Collection[Mapping[str, Any]]
45+
if TYPE_CHECKING:
46+
from collections.abc import Collection, Iterable, Iterator
47+
48+
from eligibility_signposting_api.model.person import Person
4749

4850

4951
@service
5052
class EligibilityCalculatorFactory:
5153
@staticmethod
52-
def get(person_data: Row, campaign_configs: Collection[CampaignConfig]) -> EligibilityCalculator:
53-
return EligibilityCalculator(person_data=person_data, campaign_configs=campaign_configs)
54+
def get(person: Person, campaign_configs: Collection[CampaignConfig]) -> EligibilityCalculator:
55+
return EligibilityCalculator(person=person, campaign_configs=campaign_configs)
5456

5557

5658
@dataclass
5759
class EligibilityCalculator:
58-
person_data: Row
60+
person: Person
5961
campaign_configs: Collection[CampaignConfig]
6062

6163
campaign_evaluator: CampaignEvaluator = field(default_factory=CampaignEvaluator)
@@ -238,7 +240,7 @@ def handle_action_rules(
238240
for _, rule_group in groupby(sorted_rules_by_priority, key=priority_getter):
239241
rule_group_list = list(rule_group)
240242
matcher_matched_list = [
241-
RuleCalculator(person_data=self.person_data, rule=rule).evaluate_exclusion()[1].matcher_matched
243+
RuleCalculator(person=self.person, rule=rule).evaluate_exclusion()[1].matcher_matched
242244
for rule in rule_group_list
243245
]
244246

@@ -258,7 +260,7 @@ def get_cohort_results(self, active_iteration: Iteration) -> dict[str, CohortGro
258260
filter_rules, suppression_rules = self.get_rules_by_type(active_iteration)
259261
for cohort in sorted(active_iteration.iteration_cohorts, key=attrgetter("priority")):
260262
# Base Eligibility - check
261-
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
263+
person_cohorts = self.person_data_reader.get_person_cohorts(self.person)
262264
if cohort.cohort_label in person_cohorts or cohort.is_magic_cohort:
263265
# Eligibility - check
264266
if self.is_eligible_by_filter_rules(cohort, cohort_results, filter_rules):
@@ -329,7 +331,7 @@ def is_eligible_by_filter_rules(
329331
if status.is_exclusion:
330332
if cohort.cohort_label is not None:
331333
cohort_results[cohort.cohort_label] = CohortGroupResult(
332-
(cohort.cohort_group),
334+
cohort.cohort_group,
333335
Status.not_eligible,
334336
[],
335337
cohort.negative_description,
@@ -383,7 +385,7 @@ def evaluate_rules_priority_group(
383385

384386
for rule in rules_group:
385387
is_rule_stop = rule.rule_stop or is_rule_stop
386-
rule_calculator = RuleCalculator(person_data=self.person_data, rule=rule)
388+
rule_calculator = RuleCalculator(person=self.person, rule=rule)
387389
status, reason = rule_calculator.evaluate_exclusion()
388390
if status.is_exclusion:
389391
best_status = eligibility_status.Status.best(status, best_status)

src/eligibility_signposting_api/services/calculators/rule_calculator.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Collection, Mapping
43
from dataclasses import dataclass, field
5-
from typing import Any
4+
from typing import TYPE_CHECKING
65

76
from hamcrest.core.string_description import StringDescription
87

@@ -11,12 +10,15 @@
1110
from eligibility_signposting_api.services.operators.operators import OperatorRegistry
1211
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader
1312

14-
Row = Collection[Mapping[str, Any]]
13+
if TYPE_CHECKING:
14+
from collections.abc import Mapping
15+
16+
from eligibility_signposting_api.model.person import Person
1517

1618

1719
@dataclass
1820
class RuleCalculator:
19-
person_data: Row
21+
person: Person
2022
rule: IterationRule
2123

2224
person_data_reader: PersonDataReader = field(default_factory=PersonDataReader)
@@ -39,22 +41,22 @@ def get_attribute_value(self) -> str | None:
3941
match self.rule.attribute_level:
4042
case RuleAttributeLevel.PERSON:
4143
person: Mapping[str, str | None] | None = next(
42-
(r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "PERSON"), None
44+
(r for r in self.person.data if r.get("ATTRIBUTE_TYPE", "") == "PERSON"), None
4345
)
4446
attribute_value = person.get(str(self.rule.attribute_name)) if person else None
4547
case RuleAttributeLevel.COHORT:
4648
cohorts: Mapping[str, str | None] | None = next(
47-
(r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None
49+
(r for r in self.person.data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None
4850
)
4951
if cohorts:
50-
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
52+
person_cohorts = self.person_data_reader.get_person_cohorts(self.person)
5153
attribute_value = ",".join(person_cohorts)
5254
else:
5355
attribute_value = None
5456

5557
case RuleAttributeLevel.TARGET:
5658
target: Mapping[str, str | None] | None = next(
57-
(r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == self.rule.attribute_target), None
59+
(r for r in self.person.data if r.get("ATTRIBUTE_TYPE", "") == self.rule.attribute_target), None
5860
)
5961
attribute_value = target.get(str(self.rule.attribute_name)) if target else None
6062
case _: # pragma: no cover
Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
from __future__ import annotations
22

3-
from collections.abc import Collection, Mapping
4-
from typing import Any
5-
63
from wireup import service
74

8-
Row = Collection[Mapping[str, Any]]
5+
from eligibility_signposting_api.model.person import Person
96

107

118
@service
129
class PersonDataReader:
1310
"""Handles extracting and interpreting person data."""
1411

15-
def get_person_cohorts(self, person_data: Row) -> set[str]:
16-
cohorts_row: Mapping[str, list[dict[str, str]]] = next(
17-
(row for row in person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
18-
{},
19-
)
12+
def get_person_cohorts(self, person: Person) -> set[str]:
13+
cohorts_row: Person = Person([])
14+
for data in person.data:
15+
if data.get("ATTRIBUTE_TYPE") == "COHORTS":
16+
cohorts_row.data.append(data)
17+
2018
person_cohorts = set()
2119

22-
for membership in cohorts_row.get("COHORT_MEMBERSHIPS", []):
23-
if membership.get("COHORT_LABEL"):
24-
person_cohorts.add(membership.get("COHORT_LABEL"))
20+
if cohorts_row.data:
21+
for membership in cohorts_row.data[0].get("COHORT_MEMBERSHIPS", []):
22+
if membership.get("COHORT_LABEL"):
23+
person_cohorts.add(membership.get("COHORT_LABEL"))
2524

2625
return person_cohorts

tests/fixtures/builders/repos/person.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from faker import Faker
77

8+
from eligibility_signposting_api.model.person import Person
89
from tests.conftest import PersonDetailProvider
910

1011
Gender = Literal["0", "1", "2", "9"] # 0 - Not known, 1- Male, 2 - Female, 9 - Not specified. I know, right?
@@ -27,7 +28,7 @@ def person_rows_builder( # noqa:PLR0913
2728
de: bool | None = ...,
2829
msoa: str | None = ...,
2930
lsoa: str | None = ...,
30-
) -> list[dict[str, Any]]:
31+
) -> Person:
3132
faker = Faker("en_UK")
3233
faker.add_provider(PersonDetailProvider)
3334

@@ -85,4 +86,5 @@ def person_rows_builder( # noqa:PLR0913
8586
)
8687

8788
shuffle(rows)
88-
return rows
89+
90+
return Person(data=rows)

tests/integration/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def persisted_person(person_table: Any, faker: Faker) -> Generator[eligibility_s
336336
date_of_birth = eligibility_status.DateOfBirth(faker.date_of_birth(minimum_age=18, maximum_age=65))
337337

338338
for row in (
339-
rows := person_rows_builder(nhs_number, date_of_birth=date_of_birth, postcode="hp1", cohorts=["cohort1"])
339+
rows := person_rows_builder(nhs_number, date_of_birth=date_of_birth, postcode="hp1", cohorts=["cohort1"]).data
340340
):
341341
person_table.put_item(Item=row)
342342

@@ -357,7 +357,7 @@ def persisted_77yo_person(person_table: Any, faker: Faker) -> Generator[eligibil
357357
date_of_birth=date_of_birth,
358358
postcode="hp1",
359359
cohorts=["cohort1", "cohort2"],
360-
)
360+
).data
361361
):
362362
person_table.put_item(Item=row)
363363

@@ -379,7 +379,7 @@ def persisted_person_all_cohorts(person_table: Any, faker: Faker) -> Generator[e
379379
postcode="hp1",
380380
cohorts=["cohort_label1", "cohort_label2", "cohort_label3"],
381381
icb="QE1",
382-
)
382+
).data
383383
):
384384
person_table.put_item(Item=row)
385385

@@ -393,7 +393,7 @@ def persisted_person_all_cohorts(person_table: Any, faker: Faker) -> Generator[e
393393
def persisted_person_no_cohorts(person_table: Any, faker: Faker) -> Generator[eligibility_status.NHSNumber]:
394394
nhs_number = eligibility_status.NHSNumber(faker.nhs_number())
395395

396-
for row in (rows := person_rows_builder(nhs_number)):
396+
for row in (rows := person_rows_builder(nhs_number).data):
397397
person_table.put_item(Item=row)
398398

399399
yield nhs_number
@@ -407,7 +407,7 @@ def persisted_person_pc_sw19(person_table: Any, faker: Faker) -> Generator[eligi
407407
nhs_number = eligibility_status.NHSNumber(
408408
faker.nhs_number(),
409409
)
410-
for row in (rows := person_rows_builder(nhs_number, postcode="SW19", cohorts=["cohort1"])):
410+
for row in (rows := person_rows_builder(nhs_number, postcode="SW19", cohorts=["cohort1"]).data):
411411
person_table.put_item(Item=row)
412412

413413
yield nhs_number

tests/integration/repo/test_person_repo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_person_found(person_table: Any, persisted_person: NHSNumber):
1818

1919
# Then
2020
assert_that(
21-
actual,
21+
actual.data,
2222
contains_inanyorder(
2323
has_entries({"NHS_NUMBER": persisted_person, "ATTRIBUTE_TYPE": "PERSON"}),
2424
has_entries({"NHS_NUMBER": persisted_person, "ATTRIBUTE_TYPE": "COHORTS"}),

tests/unit/services/calculators/test_eligibility_calculator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
UrlLabel,
4141
UrlLink,
4242
)
43+
from eligibility_signposting_api.model.person import Person
4344
from eligibility_signposting_api.services.calculators.eligibility_calculator import EligibilityCalculator
4445
from tests.fixtures.builders.model import rule as rule_builder
4546
from tests.fixtures.builders.repos.person import person_rows_builder
@@ -956,12 +957,12 @@ def test_status_on_cohort_attribute_level(
956957
# Given
957958
nhs_number = NHSNumber(faker.nhs_number())
958959

959-
person_row: list[dict[str, Any]] = person_rows_builder(
960-
nhs_number, cohorts=["cohort1", "covid_eligibility_complaint_list"]
961-
)
962-
person_row_with_extra_items_in_cohort_row = [
963-
{**r, "LOCATION": "HP1"} for r in person_row if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"
964-
]
960+
person_row: Person = person_rows_builder(nhs_number, cohorts=["cohort1", "covid_eligibility_complaint_list"])
961+
962+
person_row_with_extra_items_in_cohort_row = Person(person_row.data)
963+
for row in person_row_with_extra_items_in_cohort_row.data:
964+
if row.get("ATTRIBUTE_TYPE", "") == "COHORTS":
965+
row["LOCATION"] = "HP1"
965966

966967
campaign_configs = [
967968
rule_builder.CampaignConfigFactory.build(

tests/unit/services/calculators/test_rule_calculator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
1-
from collections.abc import Collection, Mapping
2-
from typing import Any
3-
41
import pytest
52

63
from eligibility_signposting_api.model.campaign_config import IterationRule, RuleAttributeLevel
4+
from eligibility_signposting_api.model.person import Person
75
from eligibility_signposting_api.services.calculators.rule_calculator import RuleCalculator
86
from tests.fixtures.builders.model import rule as rule_builder
97

10-
Row = Collection[Mapping[str, Any]]
11-
128

139
@pytest.mark.parametrize(
1410
("person_data", "rule", "expected"),
1511
[
1612
# PERSON attribute level
1713
(
18-
[{"ATTRIBUTE_TYPE": "PERSON", "POSTCODE": "SW19"}],
14+
Person([{"ATTRIBUTE_TYPE": "PERSON", "POSTCODE": "SW19"}]),
1915
rule_builder.IterationRuleFactory.build(
2016
attribute_level=RuleAttributeLevel.PERSON, attribute_name="POSTCODE"
2117
),
2218
"SW19",
2319
),
2420
# TARGET attribute level
2521
(
26-
[{"ATTRIBUTE_TYPE": "RSV", "LAST_SUCCESSFUL_DATE": "20240101"}],
22+
Person([{"ATTRIBUTE_TYPE": "RSV", "LAST_SUCCESSFUL_DATE": "20240101"}]),
2723
rule_builder.IterationRuleFactory.build(
2824
attribute_level=RuleAttributeLevel.TARGET,
2925
attribute_name="LAST_SUCCESSFUL_DATE",
@@ -33,17 +29,17 @@
3329
),
3430
# COHORT attribute level
3531
(
36-
[{"ATTRIBUTE_TYPE": "COHORTS", "COHORT_LABEL": ""}],
32+
Person([{"ATTRIBUTE_TYPE": "COHORTS", "COHORT_LABEL": ""}]),
3733
rule_builder.IterationRuleFactory.build(
3834
attribute_level=RuleAttributeLevel.COHORT, attribute_name="COHORT_LABEL"
3935
),
4036
"",
4137
),
4238
],
4339
)
44-
def test_get_attribute_value_for_all_attribute_levels(person_data: Row, rule: IterationRule, expected: str):
40+
def test_get_attribute_value_for_all_attribute_levels(person_data: Person, rule: IterationRule, expected: str):
4541
# Given
46-
calc = RuleCalculator(person_data=person_data, rule=rule)
42+
calc = RuleCalculator(person=person_data, rule=rule)
4743
# When
4844
actual = calc.get_attribute_value()
4945
# Then

0 commit comments

Comments
 (0)