diff --git a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py index a2e4f6bd8..e73bd39dc 100644 --- a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py +++ b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any from eligibility_signposting_api.audit.audit_context import AuditContext +from eligibility_signposting_api.services.processors.campaign_evaluator import CampaignEvaluator +from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader if TYPE_CHECKING: from eligibility_signposting_api.model.rules import ( @@ -58,42 +60,10 @@ class EligibilityCalculator: person_data: Row campaign_configs: Collection[rules.CampaignConfig] - results: list[eligibility_status.Condition] = field(default_factory=list) - - @property - def active_campaigns(self) -> list[rules.CampaignConfig]: - return [cc for cc in self.campaign_configs if cc.campaign_live] - - def campaigns_grouped_by_condition_name( - self, conditions: list[str], category: str - ) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]: - """Generator that yields campaign groups filtered by condition names and campaign category.""" - - mapping = { - "ALL": {"V", "S"}, - "VACCINATIONS": {"V"}, - "SCREENING": {"S"}, - } - - allowed_types = mapping.get(category, set()) + campaign_evaluator: CampaignEvaluator = field(default_factory=CampaignEvaluator) + person_data_reader: PersonDataReader = field(default_factory=PersonDataReader) - filter_all_conditions = "ALL" in conditions - - for condition_name, campaign_group in groupby( - sorted(self.active_campaigns, key=attrgetter("target")), - key=attrgetter("target"), - ): - campaigns = list(campaign_group) - if campaigns[0].type in allowed_types and (filter_all_conditions or str(condition_name) in conditions): - yield condition_name, campaigns - - @property - def person_cohorts(self) -> set[str]: - cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next( - (row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), - {}, - ) - return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys()) + results: list[eligibility_status.Condition] = field(default_factory=list) @staticmethod def get_the_best_cohort_memberships( @@ -164,7 +134,10 @@ def evaluate_eligibility( actions: list[SuggestedAction] | None = [] action_rule_priority, action_rule_name = None, None - for condition_name, campaign_group in self.campaigns_grouped_by_condition_name(conditions, category): + requested_grouped_campaigns = self.campaign_evaluator.get_requested_grouped_campaigns( + self.campaign_configs, conditions, category + ) + for condition_name, campaign_group in requested_grouped_campaigns: best_active_iteration: Iteration | None best_candidate: IterationResult best_campaign_id: CampaignID | None @@ -289,7 +262,8 @@ def get_cohort_results(self, active_iteration: rules.Iteration) -> dict[str, Coh filter_rules, suppression_rules = self.get_rules_by_type(active_iteration) for cohort in sorted(active_iteration.iteration_cohorts, key=attrgetter("priority")): # Base Eligibility - check - if cohort.cohort_label in self.person_cohorts or cohort.is_magic_cohort: + person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data) + if cohort.cohort_label in person_cohorts or cohort.is_magic_cohort: # Eligibility - check if self.is_eligible_by_filter_rules(cohort, cohort_results, filter_rules): # Actionability - evaluation diff --git a/src/eligibility_signposting_api/services/calculators/rule_calculator.py b/src/eligibility_signposting_api/services/calculators/rule_calculator.py index 96b0dfd87..4791aead6 100644 --- a/src/eligibility_signposting_api/services/calculators/rule_calculator.py +++ b/src/eligibility_signposting_api/services/calculators/rule_calculator.py @@ -1,13 +1,14 @@ from __future__ import annotations from collections.abc import Collection, Mapping -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any from hamcrest.core.string_description import StringDescription from eligibility_signposting_api.model import eligibility_status, rules -from eligibility_signposting_api.services.rules.operators import OperatorRegistry +from eligibility_signposting_api.services.operators.operators import OperatorRegistry +from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader Row = Collection[Mapping[str, Any]] @@ -17,6 +18,8 @@ class RuleCalculator: person_data: Row rule: rules.IterationRule + person_data_reader: PersonDataReader = field(default_factory=PersonDataReader) + def evaluate_exclusion(self) -> tuple[eligibility_status.Status, eligibility_status.Reason]: """Evaluate if a particular rule excludes this person. Return the result, and the reason for the result.""" attribute_value = self.get_attribute_value() @@ -43,15 +46,7 @@ def get_attribute_value(self) -> str | None: (r for r in self.person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), None ) if cohorts: - attr_name = ( - "COHORT_MAP" - if not self.rule.attribute_name or self.rule.attribute_name == "COHORT_LABEL" - else self.rule.attribute_name - ) - cohort_map = self.get_value(cohorts, attr_name) - cohorts_dict = self.get_value(cohort_map, "cohorts") - m_dict = self.get_value(cohorts_dict, "M") - person_cohorts: set[str] = set(m_dict.keys()) + person_cohorts = self.person_data_reader.get_person_cohorts(self.person_data) attribute_value = ",".join(person_cohorts) else: attribute_value = None @@ -66,11 +61,6 @@ def get_attribute_value(self) -> str | None: raise NotImplementedError(msg) return attribute_value - @staticmethod - def get_value(dictionary: Mapping[str, Any] | None, key: str) -> dict: - v = dictionary.get(key, {}) if isinstance(dictionary, dict) else {} - return v if isinstance(v, dict) else {} - def evaluate_rule(self, attribute_value: str | None) -> tuple[eligibility_status.Status, str, bool]: """Evaluate a rule against a person data attribute. Return the result, and the reason for the result.""" matcher_class = OperatorRegistry.get(self.rule.operator) diff --git a/src/eligibility_signposting_api/services/rules/__init__.py b/src/eligibility_signposting_api/services/operators/__init__.py similarity index 100% rename from src/eligibility_signposting_api/services/rules/__init__.py rename to src/eligibility_signposting_api/services/operators/__init__.py diff --git a/src/eligibility_signposting_api/services/rules/operators.py b/src/eligibility_signposting_api/services/operators/operators.py similarity index 100% rename from src/eligibility_signposting_api/services/rules/operators.py rename to src/eligibility_signposting_api/services/operators/operators.py diff --git a/src/eligibility_signposting_api/services/processors/__init__.py b/src/eligibility_signposting_api/services/processors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/eligibility_signposting_api/services/processors/campaign_evaluator.py b/src/eligibility_signposting_api/services/processors/campaign_evaluator.py new file mode 100644 index 000000000..286e43e7c --- /dev/null +++ b/src/eligibility_signposting_api/services/processors/campaign_evaluator.py @@ -0,0 +1,42 @@ +from collections.abc import Collection, Iterator +from itertools import groupby +from operator import attrgetter + +from wireup import service + +from eligibility_signposting_api.model import eligibility_status, rules + + +@service +class CampaignEvaluator: + """Filters and groups campaign configurations.""" + + def get_active_campaigns(self, campaign_configs: Collection[rules.CampaignConfig]) -> list[rules.CampaignConfig]: + return [cc for cc in campaign_configs if cc.campaign_live] + + def get_requested_grouped_campaigns( + self, campaign_configs: Collection[rules.CampaignConfig], conditions: list[str], category: str + ) -> Iterator[tuple[eligibility_status.ConditionName, list[rules.CampaignConfig]]]: + mapping = { + "ALL": {"V", "S"}, + "VACCINATIONS": {"V"}, + "SCREENING": {"S"}, + } + + allowed_types = mapping.get(category, set()) + + filter_all_conditions = "ALL" in conditions + + active_campaigns = self.get_active_campaigns(campaign_configs) + + for condition_name, campaign_group in groupby( + sorted(active_campaigns, key=attrgetter("target")), + key=attrgetter("target"), + ): + campaigns = list(campaign_group) + if ( + campaigns + and campaigns[0].type in allowed_types + and (filter_all_conditions or str(condition_name) in conditions) + ): + yield condition_name, campaigns diff --git a/src/eligibility_signposting_api/services/processors/person_data_reader.py b/src/eligibility_signposting_api/services/processors/person_data_reader.py new file mode 100644 index 000000000..20a8f6dca --- /dev/null +++ b/src/eligibility_signposting_api/services/processors/person_data_reader.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from collections.abc import Collection, Mapping +from typing import Any + +from wireup import service + +Row = Collection[Mapping[str, Any]] + + +@service +class PersonDataReader: + """Handles extracting and interpreting person data.""" + + def get_person_cohorts(self, person_data: Row) -> set[str]: + cohorts_row: Mapping[str, list[dict[str, str]]] = next( + (row for row in person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), + {}, + ) + person_cohorts = set() + + for membership in cohorts_row.get("COHORT_MEMBERSHIPS", []): + if membership.get("COHORT_LABEL"): + person_cohorts.add(membership.get("COHORT_LABEL")) + + return person_cohorts diff --git a/tests/fixtures/builders/repos/person.py b/tests/fixtures/builders/repos/person.py index 6cc418e19..6b9772acb 100644 --- a/tests/fixtures/builders/repos/person.py +++ b/tests/fixtures/builders/repos/person.py @@ -66,13 +66,9 @@ def person_rows_builder( # noqa:PLR0913 { "NHS_NUMBER": key, "ATTRIBUTE_TYPE": "COHORTS", - "COHORT_MAP": { - "cohorts": { - "M": { - cohort: {"M": {"dateJoined": {"S": faker.past_date().strftime("%Y%m%d")}}} for cohort in cohorts - } - } - }, + "COHORT_MEMBERSHIPS": [ + {"COHORT_LABEL": cohort, "DATE_JOINED": faker.past_date().strftime("%Y%m%d")} for cohort in cohorts + ], }, ] rows.extend( diff --git a/tests/unit/services/calculators/test_eligibility_calculator.py b/tests/unit/services/calculators/test_eligibility_calculator.py index 7b9c39bf5..819feb0c5 100644 --- a/tests/unit/services/calculators/test_eligibility_calculator.py +++ b/tests/unit/services/calculators/test_eligibility_calculator.py @@ -5,7 +5,7 @@ from faker import Faker from flask import Flask, g from freezegun import freeze_time -from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_, is_in +from hamcrest import assert_that, contains_exactly, contains_inanyorder, equal_to, has_item, has_items, is_in from pydantic import HttpUrl, ValidationError from eligibility_signposting_api.audit.audit_models import AuditAction, AuditEvent @@ -95,7 +95,8 @@ def test_not_base_eligible(faker: Faker): target="RSV", iterations=[ rule_builder.IterationFactory.build( - iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")] + iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")], + iteration_rules=[], ) ], ) @@ -175,6 +176,7 @@ def test_only_live_campaigns_considered(faker: Faker): iterations=[ rule_builder.IterationFactory.build( iteration_cohorts=[rule_builder.IterationCohortFactory.build(cohort_label="cohort2")], + iteration_rules=[], ) ], start_date=datetime.date(2025, 4, 20), @@ -218,7 +220,7 @@ def test_campaigns_with_applicable_iteration_types_in_campaign_level_considered( # Given nhs_number = NHSNumber(faker.nhs_number()) - person_rows = person_rows_builder(nhs_number) + person_rows = person_rows_builder(nhs_number, cohorts=[]) campaign_configs = [rule_builder.CampaignConfigFactory.build(target="RSV", iteration_type=iteration_type)] calculator = EligibilityCalculator(person_rows, campaign_configs) @@ -247,7 +249,7 @@ def test_campaigns_with_applicable_iteration_types_in_iteration_level_considered # Given nhs_number = NHSNumber(faker.nhs_number()) - person_rows = person_rows_builder(nhs_number) + person_rows = person_rows_builder(nhs_number, cohorts=[]) campaign_configs = [ rule_builder.CampaignConfigFactory.build( target="RSV", iterations=[rule_builder.IterationFactory.build(type=iteration_type)] @@ -935,11 +937,6 @@ def test_status_on_target_based_on_last_successful_date( Status.not_eligible, "cohort label is the default attribute name for the cohort attribute level", ), - ( - rules.RuleAttributeName("LOCATION"), - Status.actionable, - "attribute name that is not cohort label", - ), ], ) def test_status_on_cohort_attribute_level( @@ -2383,38 +2380,6 @@ def test_should_not_include_actions_when_include_actions_flag_is_false_when_stat ) -@pytest.mark.parametrize( - ("campaign_target", "campaign_type", "conditions_filter", "category_filter", "expected_result"), - [ - # Multiple matching campaigns under the same condition - ("RSV", "V", ["RSV"], "VACCINATIONS", [("RSV", "V")]), - ("RSV", "V", ["COVID"], "VACCINATIONS", []), - ("RSV", "S", ["RSV"], "ALL", [("RSV", "S")]), - ("RSV", "S", ["ALL"], "ALL", [("RSV", "S")]), - ("RSV", "S", ["RSV"], "VACCINATIONS", []), - # Multiple campaigns with different types under the same condition name - ("RSV", "V", ["RSV"], "ALL", [("RSV", "V")]), - # Campaign is live but condition not in filter (no yield) - ("FLU", "V", ["COVID", "RSV"], "ALL", []), - # Category is ALL and condition filter includes ALL (everything matches) - ("FLU", "S", ["ALL"], "ALL", [("FLU", "S")]), - # Condition filter is unknown (should not match anything) - ("COVID", "V", ["UNKNOWN"], "VACCINATIONS", []), - # Campaign with the target matching one of several condition filters - ("FLU", "V", ["COVID", "FLU"], "VACCINATIONS", [("FLU", "V")]), - ], -) -def test_campaigns_grouped_by_condition_name_filters_correctly( - campaign_target, campaign_type, conditions_filter, category_filter, expected_result -): - campaign = rule_builder.CampaignConfigFactory.build(target=campaign_target, type=campaign_type, campaign_live=True) - - calculator = EligibilityCalculator(person_data=[], campaign_configs=[campaign]) - result = list(calculator.campaigns_grouped_by_condition_name(conditions_filter, category_filter)) - - assert_that([(str(name), group[0].type) for name, group in result], is_(expected_result)) - - @pytest.mark.parametrize( ( "test_comment", diff --git a/tests/unit/services/operators/test_operators.py b/tests/unit/services/operators/test_operators.py index 1c2b2ba70..ffb5777c5 100644 --- a/tests/unit/services/operators/test_operators.py +++ b/tests/unit/services/operators/test_operators.py @@ -3,7 +3,7 @@ from hamcrest import assert_that, equal_to from eligibility_signposting_api.model.rules import RuleOperator -from eligibility_signposting_api.services.rules.operators import Operator, OperatorRegistry +from eligibility_signposting_api.services.operators.operators import Operator, OperatorRegistry # Test cases: person_data, rule_operator, rule_value, expected, test_comment cases: list[tuple[str | None, RuleOperator, str | None, bool, str]] = [] diff --git a/tests/unit/services/processors/__init__.py b/tests/unit/services/processors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/services/processors/test_campaign_evaluator.py b/tests/unit/services/processors/test_campaign_evaluator.py new file mode 100644 index 000000000..e968e3e9e --- /dev/null +++ b/tests/unit/services/processors/test_campaign_evaluator.py @@ -0,0 +1,117 @@ +import datetime + +import pytest +from hamcrest import assert_that, is_ + +from eligibility_signposting_api.model.rules import CampaignID +from eligibility_signposting_api.services.processors.campaign_evaluator import CampaignEvaluator +from tests.fixtures.builders.model import rule + + +@pytest.fixture +def campaign_evaluator(): + return CampaignEvaluator() + + +@pytest.mark.parametrize( + ("campaign_target", "campaign_type", "conditions_filter", "category_filter", "expected_result"), + [ + ("RSV", "V", ["RSV"], "VACCINATIONS", [("RSV", "V")]), + ("RSV", "V", ["COVID"], "VACCINATIONS", []), + ("RSV", "S", ["RSV"], "ALL", [("RSV", "S")]), + ("RSV", "S", ["ALL"], "ALL", [("RSV", "S")]), + ("RSV", "S", ["RSV"], "VACCINATIONS", []), + ("RSV", "V", ["RSV"], "ALL", [("RSV", "V")]), + ("FLU", "V", ["COVID", "RSV"], "ALL", []), + ("FLU", "S", ["ALL"], "ALL", [("FLU", "S")]), + ("COVID", "V", ["UNKNOWN"], "VACCINATIONS", []), + ("FLU", "V", ["COVID", "FLU"], "VACCINATIONS", [("FLU", "V")]), + ], +) +def test_campaigns_grouped_by_condition_name_filters_correctly( # noqa: PLR0913 + campaign_evaluator, campaign_target, campaign_type, conditions_filter, category_filter, expected_result +): + campaign = rule.CampaignConfigFactory.build(target=campaign_target, type=campaign_type) + + result = campaign_evaluator.get_requested_grouped_campaigns([campaign], conditions_filter, category_filter) + assert_that([(str(name), group[0].type) for name, group in result], is_(expected_result)) + + +def test_campaigns_grouped_by_condition_name_with_no_campaigns(campaign_evaluator): + result = campaign_evaluator.get_requested_grouped_campaigns([], ["RSV"], "VACCINATIONS") + assert_that(list(result), is_([])) + + +def test_campaigns_grouped_by_condition_name_with_no_active_campaigns(campaign_evaluator): + campaign = rule.CampaignConfigFactory.build( + target="RSV", type="V", start_date=datetime.date(2025, 4, 20), end_date=datetime.date(2025, 4, 21) + ) + + result = campaign_evaluator.get_requested_grouped_campaigns([campaign], ["RSV"], "VACCINATIONS") + assert_that(list(result), is_([])) + + +@pytest.mark.parametrize( + ("category_filter", "campaign_type", "expected_count"), + [ + ("SCREENING", "S", 1), + ("SCREENING", "V", 0), + ("INVALID_CATEGORY", "S", 0), + ], +) +def test_campaigns_grouped_by_condition_name_with_various_categories( + campaign_evaluator, category_filter, campaign_type, expected_count +): + campaign = rule.CampaignConfigFactory.build(target="COVID", type=campaign_type) + result = list(campaign_evaluator.get_requested_grouped_campaigns([campaign], ["COVID"], category_filter)) + assert_that(len(result), is_(expected_count)) + if expected_count > 0: + assert_that(str(result[0][0]), is_("COVID")) + + +def test_campaigns_grouped_by_condition_name_with_empty_conditions_filter(campaign_evaluator): + campaign = rule.CampaignConfigFactory.build(target="RSV", type="V") + result = campaign_evaluator.get_requested_grouped_campaigns([campaign], [], "VACCINATIONS") + assert_that(list(result), is_([])) + + +def test_campaigns_grouped_by_condition_name_groups_multiple_campaigns_for_same_target(campaign_evaluator): + campaign1 = rule.CampaignConfigFactory.build(target="COVID", type="V", id="C1") + campaign2 = rule.CampaignConfigFactory.build(target="COVID", type="V", id="C2") + campaign3 = rule.CampaignConfigFactory.build(target="FLU", type="V", id="F1") + inactive_campaign = rule.CampaignConfigFactory.build( + target="COVID", type="V", id="C3", start_date=datetime.date(2025, 4, 20), end_date=datetime.date(2025, 4, 21) + ) + + all_campaigns = [campaign1, campaign2, campaign3, inactive_campaign] + result = list(campaign_evaluator.get_requested_grouped_campaigns(all_campaigns, ["COVID", "FLU"], "VACCINATIONS")) + + assert_that(len(result), is_(2)) + + result_dict = {str(name): campaigns for name, campaigns in result} + assert_that("COVID" in result_dict) + assert_that("FLU" in result_dict) + + assert_that(len(result_dict["COVID"]), is_(2)) + assert_that({c.id for c in result_dict["COVID"]}, is_({CampaignID("C1"), CampaignID("C2")})) + + assert_that(len(result_dict["FLU"]), is_(1)) + assert_that(result_dict["FLU"][0].id, is_(CampaignID("F1"))) + + +def test_campaign_grouping_is_affected_by_order_for_mixed_types(campaign_evaluator): + campaign_v = rule.CampaignConfigFactory.build(target="RSV", type="V") + campaign_s = rule.CampaignConfigFactory.build(target="RSV", type="S") + + evaluator_s_first = campaign_evaluator + result_s_first = list( + evaluator_s_first.get_requested_grouped_campaigns([campaign_s, campaign_v], ["RSV"], "VACCINATIONS") + ) + assert_that(result_s_first, is_([])) + + evaluator_v_first = campaign_evaluator + result_v_first = list( + evaluator_v_first.get_requested_grouped_campaigns([campaign_v, campaign_s], ["RSV"], "VACCINATIONS") + ) + assert_that(len(result_v_first), is_(1)) + assert_that(len(result_v_first[0][1]), is_(2)) diff --git a/tests/unit/services/processors/test_person_data_reader.py b/tests/unit/services/processors/test_person_data_reader.py new file mode 100644 index 000000000..2191c44b3 --- /dev/null +++ b/tests/unit/services/processors/test_person_data_reader.py @@ -0,0 +1,86 @@ +import pytest +from hamcrest import assert_that, is_ + +from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader + + +@pytest.fixture +def person_data_reader(): + return PersonDataReader() + + +def test_get_person_cohorts_empty_data(person_data_reader): + result = person_data_reader.get_person_cohorts([]) + assert_that(result, is_(set())) + + +def test_get_person_cohorts_no_cohorts_attribute_type(person_data_reader): + no_cohorts_type = [ + {"ATTRIBUTE_TYPE": "NAME", "VALUE": "John Doe"}, + {"ATTRIBUTE_TYPE": "AGE", "VALUE": 30}, + ] + result = person_data_reader.get_person_cohorts(no_cohorts_type) + assert_that(result, is_(set())) + + +def test_get_person_cohorts_no_cohort_map_key(person_data_reader): + no_cohorts_map = [ + {"ATTRIBUTE_TYPE": "COHORTS", "OTHER_FIELD": "value"}, + ] + result = person_data_reader.get_person_cohorts(no_cohorts_map) + assert_that(result, is_(set())) + + +def test_get_person_cohorts_single_cohort(person_data_reader): + single_cohorts = [ + { + "ATTRIBUTE_TYPE": "COHORTS", + "COHORT_MEMBERSHIPS": [{"COHORT_LABEL": "flu_65+_autumnwinter2023", "DATE_JOINED": "20231020"}], + }, + {"ATTRIBUTE_TYPE": "NAME", "VALUE": "Jane Smith"}, + ] + result = person_data_reader.get_person_cohorts(single_cohorts) + assert_that(result, is_({"flu_65+_autumnwinter2023"})) + + +def test_get_person_cohorts_multiple_cohorts(person_data_reader): + multiple_cohorts = [ + { + "ATTRIBUTE_TYPE": "COHORTS", + "COHORT_MEMBERSHIPS": [ + {"COHORT_LABEL": "COHORT_B", "DATE_JOINED": "20231020"}, + {"COHORT_LABEL": "COHORT_C", "DATE_JOINED": "20241020"}, + ], + }, + {"ATTRIBUTE_TYPE": "AGE", "VALUE": 45}, + ] + result = person_data_reader.get_person_cohorts(multiple_cohorts) + assert_that(result, is_({"COHORT_B", "COHORT_C"})) + + +def test_get_person_cohorts_mixed_data(person_data_reader): + mixed_data = [ + { + "ATTRIBUTE_TYPE": "COHORTS", + "COHORT_MEMBERSHIPS": [ + {"COHORT_LABEL": "COHORT_D", "DATE_JOINED": "20231020"}, + {"COHORT_LABEL": "COHORT_E", "DATE_JOINED": "20241020"}, + ], + }, + {"ATTRIBUTE_TYPE": "NAME", "VALUE": "Alice"}, + {"ATTRIBUTE_TYPE": "ADDRESS", "VALUE": "123 Main St"}, + ] + + result = person_data_reader.get_person_cohorts(mixed_data) + assert_that(result, is_({"COHORT_D", "COHORT_E"})) + + +def test_get_person_cohorts_with_other_attribute_types_present(person_data_reader): + data = [ + {"ATTRIBUTE_TYPE": "COHORTS", "COHORT_MEMBERSHIPS": [{"COHORT_LABEL": "COHORT_F", "DATE_JOINED": "20231020"}]}, + {"ATTRIBUTE_TYPE": "NAME", "VALUE": "Charlie"}, + {"ATTRIBUTE_TYPE": "AGE", "VALUE": 25}, + ] + + result = person_data_reader.get_person_cohorts(data) + assert_that(result, is_({"COHORT_F"}))