Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class AddDaysHandler(DerivedValueHandler):

function_name: str = "ADD_DAYS"

# Mapping of derived attribute names to their source attributes
DERIVED_ATTRIBUTE_SOURCES: ClassVar[dict[str, str]] = {
"NEXT_DOSE_DUE": "LAST_SUCCESSFUL_DATE",
}
Expand Down Expand Up @@ -62,7 +61,6 @@ def get_source_attribute(self, target_attribute: str, function_args: str | None
The source attribute name (e.g., 'LAST_SUCCESSFUL_DATE')
"""
if function_args and "," in function_args:
# Extract source from args if present (second argument)
parts = [p.strip() for p in function_args.split(",")]
if len(parts) > 1 and parts[1]:
return parts[1].upper()
Expand Down Expand Up @@ -98,6 +96,9 @@ def calculate(self, context: DerivedValueContext) -> str:
def _find_source_date(self, context: DerivedValueContext) -> str | None:
"""Find the source date value from person data.

For PERSON/COHORT-level attributes, looks for ATTRIBUTE_TYPE == attribute_level.
For TARGET-level attributes, looks for ATTRIBUTE_TYPE == context.attribute_name (e.g., "COVID").

Args:
context: The derived value context

Expand All @@ -108,8 +109,13 @@ def _find_source_date(self, context: DerivedValueContext) -> str | None:
if not source_attr:
return None

if context.attribute_level in ("PERSON", "COHORT"):
attribute_type_to_match = context.attribute_level
else:
attribute_type_to_match = context.attribute_name

for attribute in context.person_data:
if attribute.get("ATTRIBUTE_TYPE") == context.attribute_name:
if attribute.get("ATTRIBUTE_TYPE") == attribute_type_to_match:
return attribute.get(source_attr)

return None
Expand All @@ -128,7 +134,6 @@ def _get_days_to_add(self, context: DerivedValueContext) -> int:
Returns:
Number of days to add
"""
# Priority 1: Token argument (if non-empty)
if context.function_args:
args = context.function_args.split(",")[0].strip()
if args:
Expand All @@ -138,11 +143,9 @@ def _get_days_to_add(self, context: DerivedValueContext) -> int:
message = f"Invalid days argument '{args}' for ADD_DAYS function. Expected an integer."
raise ValueError(message) from e

# Priority 2: Vaccine-specific configuration
if context.attribute_name in self.vaccine_type_days:
return self.vaccine_type_days[context.attribute_name]

# Priority 3: Default
return self.default_days

def _add_days_to_date(self, date_str: str, days: int) -> datetime:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@ class DerivedValueContext:

Attributes:
person_data: List of person attribute dictionaries
attribute_name: The condition/vaccine type (e.g., 'COVID', 'RSV')
attribute_name: The condition/vaccine type (e.g., 'COVID', 'RSV') or person/cohort attribute
(e.g., 'DATE_OF_BIRTH')
source_attribute: The source attribute to derive from (e.g., 'LAST_SUCCESSFUL_DATE')
function_args: Arguments passed to the function (e.g., number of days)
date_format: Optional date format string for output formatting
attribute_level: The level of the attribute ('TARGET', 'PERSON' or 'COHORT')
"""

person_data: list[dict[str, Any]]
attribute_name: str
source_attribute: str | None
function_args: str | None
date_format: str | None
attribute_level: str = "TARGET"


class DerivedValueHandler(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def get_token_replacement(token: str, person_data: list[dict], present_attribute
if TokenProcessor.should_replace_with_empty(parsed_token, present_attributes):
return ""

# Check if this is a derived value (has a function like ADD_DAYS)
if parsed_token.function_name:
return TokenProcessor.get_derived_value(parsed_token, person_data, present_attributes, token)

TokenProcessor.validate_target_attribute(parsed_token, token)

found_attribute, key_to_replace = TokenProcessor.find_matching_attribute(parsed_token, person_data)

if not found_attribute or not key_to_replace:
TokenProcessor.handle_token_not_found(parsed_token, token)
# handle_token_not_found always raises, but the type checker needs help
msg = "Unreachable"
raise RuntimeError(msg) # pragma: no cover

Expand All @@ -113,6 +113,11 @@ def get_derived_value(
) -> str:
"""Calculate a derived value using the registered handler.

For TARGET level tokens, validates that the condition is allowed before processing.
If the vaccine type is not in person data, returns an empty string.
For derived values, any target attribute name is allowed (e.g., NEXT_BOOKING_AVAILABLE)
since it's just a placeholder that may be surfaced in the future.

Args:
parsed_token: The parsed token containing function information
person_data: List of person attribute dictionaries
Expand All @@ -136,20 +141,14 @@ def get_derived_value(
message = f"Unknown function '{function_name}' in token '{token}'."
raise ValueError(message)

# For TARGET level tokens, validate the condition is allowed
if parsed_token.attribute_level == TARGET_ATTRIBUTE_LEVEL:
is_allowed_condition = parsed_token.attribute_name in ALLOWED_CONDITIONS.__args__
is_allowed_target_attr = parsed_token.attribute_value in ALLOWED_TARGET_ATTRIBUTES

# If condition is not allowed, raise error
if not is_allowed_condition:
TokenProcessor.handle_token_not_found(parsed_token, token)

# If vaccine type is not in person data but is allowed, return empty
if parsed_token.attribute_name not in present_attributes:
if is_allowed_target_attr:
return ""
TokenProcessor.handle_token_not_found(parsed_token, token)
return ""

try:
target_attribute = parsed_token.attribute_value or parsed_token.attribute_name
Expand All @@ -165,14 +164,14 @@ def get_derived_value(
source_attribute=source_attribute,
function_args=parsed_token.function_args,
date_format=parsed_token.format,
attribute_level=parsed_token.attribute_level,
)

return registry.calculate(
function_name=function_name,
context=context,
)
except ValueError as e:
# Re-raise with more context
message = f"Error calculating derived value for token '{token}': {e}"
raise ValueError(message) from e

Expand All @@ -185,6 +184,26 @@ def should_replace_with_empty(parsed_token: ParsedToken, present_attributes: set

return all([is_target_level, is_allowed_condition, is_allowed_target_attr, is_attr_not_present])

@staticmethod
def validate_target_attribute(parsed_token: ParsedToken, token: str) -> None:
"""Validate that target attribute is allowed for non-derived tokens.

For regular (non-derived) tokens, only allow known target attributes.
Derived values with functions can use any custom target attribute name.

Args:
parsed_token: The parsed token to validate
token: The original token string for error messages

Raises:
ValueError: If the target attribute is not in ALLOWED_TARGET_ATTRIBUTES
"""
if (
parsed_token.attribute_level == TARGET_ATTRIBUTE_LEVEL
and parsed_token.attribute_value not in ALLOWED_TARGET_ATTRIBUTES
):
TokenProcessor.handle_token_not_found(parsed_token, token)

@staticmethod
def find_matching_attribute(parsed_token: ParsedToken, person_data: list[dict]) -> tuple[dict | None, str | None]:
attribute_level_map = {
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,48 @@ def campaign_config_with_multiple_add_days(
s3_client.delete_object(Bucket=rules_bucket, Key=f"{campaign.name}.json")


@pytest.fixture
def campaign_config_with_custom_target_attributes(
s3_client: BaseClient, rules_bucket: BucketName
) -> Generator[CampaignConfig]:
"""Campaign config with custom target attribute names for derived values."""
campaign: CampaignConfig = rule.CampaignConfigFactory.build(
target="COVID",
iterations=[
rule.IterationFactory.build(
default_comms_routing="CUSTOM_BOOKING_DATE",
actions_mapper=rule.ActionsMapperFactory.build(
root={
"CUSTOM_BOOKING_DATE": AvailableAction(
ActionType="DataValue",
ExternalRoutingCode="NextBookingAvailable",
ActionDescription=(
"[[TARGET.COVID.NEXT_BOOKING_AVAILABLE:ADD_DAYS(71, LAST_SUCCESSFUL_DATE):"
"DATE(%d %B %Y)]]"
),
),
}
),
iteration_rules=[],
iteration_cohorts=[
rule.IterationCohortFactory.build(
cohort_label="cohort_label1",
cohort_group="cohort_group1",
positive_description="Positive Description",
negative_description="Negative Description",
)
],
)
],
)
campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)}
s3_client.put_object(
Bucket=rules_bucket, Key=f"{campaign.name}.json", Body=json.dumps(campaign_data), ContentType="application/json"
)
yield campaign
s3_client.delete_object(Bucket=rules_bucket, Key=f"{campaign.name}.json")


@pytest.fixture(scope="class")
def multiple_campaign_configs(s3_client: BaseClient, rules_bucket: BucketName) -> Generator[list[CampaignConfig]]:
"""Create and upload multiple campaign configs to S3, then clean up after tests."""
Expand Down Expand Up @@ -1521,6 +1563,20 @@ def consumer_to_active_campaign_config_with_multiple_add_days_mapping(
s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json")


@pytest.fixture
def consumer_to_active_campaign_config_with_custom_target_attributes_mapping(
s3_client: BaseClient,
consumer_mapping_bucket: ConsumerMapping,
campaign_config_with_custom_target_attributes: CampaignConfig,
consumer_id: ConsumerId,
):
consumer_mapping = create_and_put_consumer_mapping_in_s3(
campaign_config_with_custom_target_attributes, consumer_id, consumer_mapping_bucket, s3_client
)
yield consumer_mapping
s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json")


@pytest.fixture
def consumer_to_campaign_having_inactive_iteration_mapping(
s3_client: BaseClient,
Expand Down
62 changes: 62 additions & 0 deletions tests/integration/in_process/test_derived_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,65 @@ def test_multiple_actions_with_different_add_days_parameters(
has_item(has_entries(actionCode="DateOfNextDoseAt61Days", description="20260330")),
),
)


class TestCustomTargetAttributeNames:
"""Test that custom target attribute names work with derived values in integration."""

def test_custom_target_attribute_with_derived_value(
self,
client: FlaskClient,
person_with_covid_vaccination: NHSNumber,
consumer_id: ConsumerId,
consumer_to_active_campaign_config_with_custom_target_attributes_mapping: ConsumerMapping, # noqa: ARG002
secretsmanager_client: BaseClient, # noqa: ARG002
):
"""
Test that custom target attribute names like NEXT_BOOKING_AVAILABLE work with derived values.

This tests the issue reported in production where:
[[TARGET.COVID.NEXT_BOOKING_AVAILABLE:ADD_DAYS(71, LAST_SUCCESSFUL_DATE):DATE(%d %B %Y)]]
was raising a ValueError.

Given:
- A person with COVID vaccination on 2026-01-28
- A campaign config using a custom target attribute: NEXT_BOOKING_AVAILABLE
- The token: [[TARGET.COVID.NEXT_BOOKING_AVAILABLE:ADD_DAYS(71, LAST_SUCCESSFUL_DATE):DATE(%d %B %Y)]]

Expected:
- Should calculate 2026-01-28 + 71 days = 2026-04-09
- Should format as "09 April 2026"
- Should NOT raise ValueError about invalid attribute name
"""
# Given
headers = {"nhs-login-nhs-number": str(person_with_covid_vaccination), UNIQUE_CONSUMER_HEADER: str(consumer_id)}

# When
response = client.get(
f"/patient-check/{person_with_covid_vaccination}?includeActions=Y",
headers=headers,
)

# Then
assert_that(
response,
is_response().with_status_code(HTTPStatus.OK).and_text(is_json_that(has_key("processedSuggestions"))),
)

body = response.get_json()
assert_that(body, is_not(none()))
processed_suggestions = body.get("processedSuggestions", [])

covid_suggestion = next(
(s for s in processed_suggestions if s.get("condition") == "COVID"),
None,
)
assert_that(covid_suggestion, is_not(none()))

actions = covid_suggestion.get("actions", []) # type: ignore[union-attr]

# Verify the custom target attribute with derived value works correctly
assert_that(
actions,
has_item(has_entries(actionCode="NextBookingAvailable", description="09 April 2026")),
)
32 changes: 32 additions & 0 deletions tests/unit/services/processors/test_derived_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,38 @@ def test_function_args_priority_over_vaccine_config(self):
# 2025-01-01 + 30 days = 2025-01-31
assert_that(result, is_(equal_to("20250131")))

def test_calculate_returns_empty_when_source_attribute_is_none(self):
"""Test that empty string is returned when source_attribute is None."""
handler = AddDaysHandler(default_days=91)
context = DerivedValueContext(
person_data=[{"ATTRIBUTE_TYPE": "COVID", "LAST_SUCCESSFUL_DATE": "20250101"}],
attribute_name="COVID",
source_attribute=None, # This should cause early return None in _find_source_date
function_args=None,
date_format=None,
attribute_level="TARGET",
)

result = handler.calculate(context)

assert_that(result, is_(equal_to("")))

def test_calculate_returns_empty_when_source_attribute_is_empty(self):
"""Test that empty string is returned when source_attribute is empty string."""
handler = AddDaysHandler(default_days=91)
context = DerivedValueContext(
person_data=[{"ATTRIBUTE_TYPE": "COVID", "LAST_SUCCESSFUL_DATE": "20250101"}],
attribute_name="COVID",
source_attribute="", # This should cause early return None in _find_source_date
function_args=None,
date_format=None,
attribute_level="TARGET",
)

result = handler.calculate(context)

assert_that(result, is_(equal_to("")))


class TestDerivedValueRegistry:
"""Tests for the DerivedValueRegistry class."""
Expand Down
Loading
Loading