Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import TYPE_CHECKING, Any

from eligibility_signposting_api.audit.audit_context import AuditContext
from eligibility_signposting_api.services.processors.campaign_evaluator import CampaignEvaluator
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader

if TYPE_CHECKING:
from eligibility_signposting_api.model.rules import (
Expand Down Expand Up @@ -58,42 +60,10 @@ class EligibilityCalculator:
person_data: Row
campaign_configs: Collection[rules.CampaignConfig]

results: list[eligibility_status.Condition] = field(default_factory=list)

@property
def active_campaigns(self) -> list[rules.CampaignConfig]:
return [cc for cc in self.campaign_configs if cc.campaign_live]

def campaigns_grouped_by_condition_name(
self, conditions: list[str], category: str
) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]:
"""Generator that yields campaign groups filtered by condition names and campaign category."""

mapping = {
"ALL": {"V", "S"},
"VACCINATIONS": {"V"},
"SCREENING": {"S"},
}

allowed_types = mapping.get(category, set())
campaign_evaluator: CampaignEvaluator = field(default_factory=CampaignEvaluator)
person_data_reader: PersonDataReader = field(default_factory=PersonDataReader)

filter_all_conditions = "ALL" in conditions

for condition_name, campaign_group in groupby(
sorted(self.active_campaigns, key=attrgetter("target")),
key=attrgetter("target"),
):
campaigns = list(campaign_group)
if campaigns[0].type in allowed_types and (filter_all_conditions or str(condition_name) in conditions):
yield condition_name, campaigns

@property
def person_cohorts(self) -> set[str]:
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
{},
)
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
results: list[eligibility_status.Condition] = field(default_factory=list)

@staticmethod
def get_the_best_cohort_memberships(
Expand Down Expand Up @@ -164,7 +134,10 @@ def evaluate_eligibility(
actions: list[SuggestedAction] | None = []
action_rule_priority, action_rule_name = None, None

for condition_name, campaign_group in self.campaigns_grouped_by_condition_name(conditions, category):
requested_grouped_campaigns = self.campaign_evaluator.get_requested_grouped_campaigns(
self.campaign_configs, conditions, category
)
for condition_name, campaign_group in requested_grouped_campaigns:
best_active_iteration: Iteration | None
best_candidate: IterationResult
best_campaign_id: CampaignID | None
Expand Down Expand Up @@ -289,7 +262,8 @@ def get_cohort_results(self, active_iteration: rules.Iteration) -> dict[str, Coh
filter_rules, suppression_rules = self.get_rules_by_type(active_iteration)
for cohort in sorted(active_iteration.iteration_cohorts, key=attrgetter("priority")):
# Base Eligibility - check
if cohort.cohort_label in self.person_cohorts or cohort.is_magic_cohort:
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
if cohort.cohort_label in person_cohorts or cohort.is_magic_cohort:
# Eligibility - check
if self.is_eligible_by_filter_rules(cohort, cohort_results, filter_rules):
# Actionability - evaluation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from collections.abc import Collection, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any

from hamcrest.core.string_description import StringDescription

from eligibility_signposting_api.model import eligibility_status, rules
from eligibility_signposting_api.services.rules.operators import OperatorRegistry
from eligibility_signposting_api.services.operators.operators import OperatorRegistry
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader

Row = Collection[Mapping[str, Any]]

Expand All @@ -17,6 +18,8 @@ class RuleCalculator:
person_data: Row
rule: rules.IterationRule

person_data_reader: PersonDataReader = field(default_factory=PersonDataReader)

def evaluate_exclusion(self) -> tuple[eligibility_status.Status, eligibility_status.Reason]:
"""Evaluate if a particular rule excludes this person. Return the result, and the reason for the result."""
attribute_value = self.get_attribute_value()
Expand All @@ -43,15 +46,7 @@ def get_attribute_value(self) -> str | None:
(r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None
)
if cohorts:
attr_name = (
"COHORT_MAP"
if not self.rule.attribute_name or self.rule.attribute_name == "COHORT_LABEL"
else self.rule.attribute_name
)
cohort_map = self.get_value(cohorts, attr_name)
cohorts_dict = self.get_value(cohort_map, "cohorts")
m_dict = self.get_value(cohorts_dict, "M")
person_cohorts: set[str] = set(m_dict.keys())
person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data)
attribute_value = ",".join(person_cohorts)
else:
attribute_value = None
Expand All @@ -66,11 +61,6 @@ def get_attribute_value(self) -> str | None:
raise NotImplementedError(msg)
return attribute_value

@staticmethod
def get_value(dictionary: Mapping[str, Any] | None, key: str) -> dict:
v = dictionary.get(key, {}) if isinstance(dictionary, dict) else {}
return v if isinstance(v, dict) else {}

def evaluate_rule(self, attribute_value: str | None) -> tuple[eligibility_status.Status, str, bool]:
"""Evaluate a rule against a person data attribute. Return the result, and the reason for the result."""
matcher_class = OperatorRegistry.get(self.rule.operator)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from collections.abc import Collection, Iterator
from itertools import groupby
from operator import attrgetter

from wireup import service

from eligibility_signposting_api.model import eligibility_status, rules


@service
class CampaignEvaluator:
"""Filters and groups campaign configurations."""

def get_active_campaigns(self, campaign_configs: Collection[rules.CampaignConfig]) -> list[rules.CampaignConfig]:
return [cc for cc in campaign_configs if cc.campaign_live]

def get_requested_grouped_campaigns(
self, campaign_configs: Collection[rules.CampaignConfig], conditions: list[str], category: str
) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]:
mapping = {
"ALL": {"V", "S"},
"VACCINATIONS": {"V"},
"SCREENING": {"S"},
}

allowed_types = mapping.get(category, set())

filter_all_conditions = "ALL" in conditions

active_campaigns = self.get_active_campaigns(campaign_configs)

for condition_name, campaign_group in groupby(
sorted(active_campaigns, key=attrgetter("target")),
key=attrgetter("target"),
):
campaigns = list(campaign_group)
if (
campaigns
and campaigns[0].type in allowed_types
and (filter_all_conditions or str(condition_name) in conditions)
):
yield condition_name, campaigns
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from collections.abc import Collection, Mapping
from typing import Any

from wireup import service

Row = Collection[Mapping[str, Any]]


@service
class PersonDataReader:
"""Handles extracting and interpreting person data."""

def get_person_cohorts(self, person_data: Row) -> set[str]:
cohorts_row: Mapping[str, list[dict[str, str]]] = next(
(row for row in person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"),
{},
)
person_cohorts = set()

for membership in cohorts_row.get("COHORT_MEMBERSHIPS", []):
if membership.get("COHORT_LABEL"):
person_cohorts.add(membership.get("COHORT_LABEL"))

return person_cohorts
10 changes: 3 additions & 7 deletions tests/fixtures/builders/repos/person.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,9 @@ def person_rows_builder( # noqa:PLR0913
{
"NHS_NUMBER": key,
"ATTRIBUTE_TYPE": "COHORTS",
"COHORT_MAP": {
"cohorts": {
"M": {
cohort: {"M": {"dateJoined": {"S": faker.past_date().strftime("%Y%m%d")}}} for cohort in cohorts
}
}
},
"COHORT_MEMBERSHIPS": [
{"COHORT_LABEL": cohort, "DATE_JOINED": faker.past_date().strftime("%Y%m%d")} for cohort in cohorts
],
},
]
rows.extend(
Expand Down
47 changes: 6 additions & 41 deletions tests/unit/services/calculators/test_eligibility_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from faker import Faker
from flask import Flask, g
from freezegun import freeze_time
from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_, is_in
from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_in
from pydantic import HttpUrl, ValidationError

from eligibility_signposting_api.audit.audit_models import AuditAction, AuditEvent
Expand Down Expand Up @@ -95,7 +95,8 @@ def test_not_base_eligible(faker: Faker):
target="RSV",
iterations=[
rule_builder.IterationFactory.build(
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")]
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")],
iteration_rules=[],
)
],
)
Expand Down Expand Up @@ -175,6 +176,7 @@ def test_only_live_campaigns_considered(faker: Faker):
iterations=[
rule_builder.IterationFactory.build(
iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")],
iteration_rules=[],
)
],
start_date=datetime.date(2025, 4, 20),
Expand Down Expand Up @@ -218,7 +220,7 @@ def test_campaigns_with_applicable_iteration_types_in_campaign_level_considered(
# Given
nhs_number = NHSNumber(faker.nhs_number())

person_rows = person_rows_builder(nhs_number)
person_rows = person_rows_builder(nhs_number, cohorts=[])
campaign_configs = [rule_builder.CampaignConfigFactory.build(target="RSV", iteration_type=iteration_type)]

calculator = EligibilityCalculator(person_rows, campaign_configs)
Expand Down Expand Up @@ -247,7 +249,7 @@ def test_campaigns_with_applicable_iteration_types_in_iteration_level_considered
# Given
nhs_number = NHSNumber(faker.nhs_number())

person_rows = person_rows_builder(nhs_number)
person_rows = person_rows_builder(nhs_number, cohorts=[])
campaign_configs = [
rule_builder.CampaignConfigFactory.build(
target="RSV", iterations=[rule_builder.IterationFactory.build(type=iteration_type)]
Expand Down Expand Up @@ -935,11 +937,6 @@ def test_status_on_target_based_on_last_successful_date(
Status.not_eligible,
"cohort label is the default attribute name for the cohort attribute level",
),
(
rules.RuleAttributeName("LOCATION"),
Status.actionable,
"attribute name that is not cohort label",
),
],
)
def test_status_on_cohort_attribute_level(
Expand Down Expand Up @@ -2383,38 +2380,6 @@ def test_should_not_include_actions_when_include_actions_flag_is_false_when_stat
)


@pytest.mark.parametrize(
("campaign_target", "campaign_type", "conditions_filter", "category_filter", "expected_result"),
[
# Multiple matching campaigns under the same condition
("RSV", "V", ["RSV"], "VACCINATIONS", [("RSV", "V")]),
("RSV", "V", ["COVID"], "VACCINATIONS", []),
("RSV", "S", ["RSV"], "ALL", [("RSV", "S")]),
("RSV", "S", ["ALL"], "ALL", [("RSV", "S")]),
("RSV", "S", ["RSV"], "VACCINATIONS", []),
# Multiple campaigns with different types under the same condition name
("RSV", "V", ["RSV"], "ALL", [("RSV", "V")]),
# Campaign is live but condition not in filter (no yield)
("FLU", "V", ["COVID", "RSV"], "ALL", []),
# Category is ALL and condition filter includes ALL (everything matches)
("FLU", "S", ["ALL"], "ALL", [("FLU", "S")]),
# Condition filter is unknown (should not match anything)
("COVID", "V", ["UNKNOWN"], "VACCINATIONS", []),
# Campaign with the target matching one of several condition filters
("FLU", "V", ["COVID", "FLU"], "VACCINATIONS", [("FLU", "V")]),
],
)
def test_campaigns_grouped_by_condition_name_filters_correctly(
campaign_target, campaign_type, conditions_filter, category_filter, expected_result
):
campaign = rule_builder.CampaignConfigFactory.build(target=campaign_target, type=campaign_type, campaign_live=True)

calculator = EligibilityCalculator(person_data=[], campaign_configs=[campaign])
result = list(calculator.campaigns_grouped_by_condition_name(conditions_filter, category_filter))

assert_that([(str(name), group[0].type) for name, group in result], is_(expected_result))


@pytest.mark.parametrize(
(
"test_comment",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/services/operators/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hamcrest import assert_that, equal_to

from eligibility_signposting_api.model.rules import RuleOperator
from eligibility_signposting_api.services.rules.operators import Operator, OperatorRegistry
from eligibility_signposting_api.services.operators.operators import Operator, OperatorRegistry

# Test cases: person_data, rule_operator, rule_value, expected, test_comment
cases: list[tuple[str | None, RuleOperator, str | None, bool, str]] = []
Expand Down
Empty file.
Loading