diff --git a/src/eligibility_signposting_api/app.py b/src/eligibility_signposting_api/app.py index 237673237..2cf614bea 100644 --- a/src/eligibility_signposting_api/app.py +++ b/src/eligibility_signposting_api/app.py @@ -11,6 +11,7 @@ from eligibility_signposting_api.config.config import config, init_logging from eligibility_signposting_api.error_handler import handle_exception from eligibility_signposting_api.views import eligibility_blueprint +from eligibility_signposting_api.wrapper import validate_matching_nhs_number init_logging() logger = logging.getLogger(__name__) @@ -22,6 +23,7 @@ def main() -> None: # pragma: no cover app.run(debug=config()["log_level"] == logging.DEBUG) +@validate_matching_nhs_number() def lambda_handler(event: LambdaEvent, context: LambdaContext) -> dict[str, Any]: # pragma: no cover """Run the Flask app as an AWS Lambda.""" app = create_app() diff --git a/src/eligibility_signposting_api/config/contants.py b/src/eligibility_signposting_api/config/contants.py index 853ef1990..9756b3081 100644 --- a/src/eligibility_signposting_api/config/contants.py +++ b/src/eligibility_signposting_api/config/contants.py @@ -1,2 +1,3 @@ MAGIC_COHORT_LABEL = "elid_all_people" RULE_STOP_DEFAULT = False +NHS_NUMBER_HEADER_NAME = "nhs-login-nhs-number" diff --git a/src/eligibility_signposting_api/wrapper.py b/src/eligibility_signposting_api/wrapper.py new file mode 100644 index 000000000..ff4ff63bc --- /dev/null +++ b/src/eligibility_signposting_api/wrapper.py @@ -0,0 +1,35 @@ +import logging +from collections.abc import Callable +from functools import wraps + +from mangum.types import LambdaContext, LambdaEvent + +from eligibility_signposting_api.config.contants import NHS_NUMBER_HEADER_NAME + +logger = logging.getLogger(__name__) + + +class MismatchedNHSNumberError(ValueError): + pass + + +def validate_matching_nhs_number() -> Callable: + def decorator(func: Callable) -> Callable: # pragma: no cover + @wraps(func) + def wrapper(event: LambdaEvent, context: LambdaContext) -> dict[str, int | str]: + headers = event.get("headers", {}) + path_params = event.get("pathParameters", {}) + + header_nhs = headers.get(NHS_NUMBER_HEADER_NAME) + path_nhs = path_params.get("id") + + logger.info("nhs numbers from the request", extra={"header_nhs": header_nhs, "path_nhs": path_nhs}) + + if header_nhs != path_nhs: + logger.error("NHS number mismatch", extra={"header_nhs_no": header_nhs, "path_nhs_no": path_nhs}) + return {"statusCode": 403, "body": "NHS number mismatch"} + return func(event, context) + + return wrapper + + return decorator diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index e6b90dd5a..8cf0ec1e8 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -9,7 +9,7 @@ services: # LocalStack configuration: https://docs.localstack.cloud/references/configuration/ - DEBUG=${LOCALSTACK_DEBUG:-0} - DEFAULT_REGION=${AWS_DEFAULT_REGION:-eu-west-1} - - LAMBDA_EXECUTOR=docker + - LAMBDA_EXECUTOR=local volumes: - "${LOCALSTACK_VOLUME_DIR:-../volume}:/var/lib/localstack" - "/var/run/docker.sock:/var/run/docker.sock" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d1a775e3f..6aaf6c4f7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -62,6 +62,11 @@ def boto3_session() -> Session: return Session(aws_access_key_id="fake", aws_secret_access_key="fake", region_name=AWS_REGION) +@pytest.fixture(scope="session") +def api_gateway_client(boto3_session: Session, localstack: URL) -> BaseClient: + return boto3_session.client("apigateway", endpoint_url=str(localstack)) + + @pytest.fixture(scope="session") def lambda_client(boto3_session: Session, localstack: URL) -> BaseClient: return boto3_session.client("lambda", endpoint_url=str(localstack)) @@ -123,18 +128,42 @@ def iam_role(iam_client: BaseClient) -> Generator[str]: } ], } + dynamodb_policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "dynamodb:GetItem", + "dynamodb:PutItem", + "dynamodb:UpdateItem", + "dynamodb:DeleteItem", + "dynamodb:Scan", + "dynamodb:Query", + ], + "Resource": "arn:aws:dynamodb:*:*:table/*", + } + ], + } - # Create the IAM Policy - policy = iam_client.create_policy(PolicyName=policy_name, PolicyDocument=json.dumps(log_policy)) - policy_arn = policy["Policy"]["Arn"] + # Create CloudWatch Logs policy (as before) + log_policy_resp = iam_client.create_policy(PolicyName=policy_name, PolicyDocument=json.dumps(log_policy)) + log_policy_arn = log_policy_resp["Policy"]["Arn"] + iam_client.attach_role_policy(RoleName=role_name, PolicyArn=log_policy_arn) - # Attach Policy to Role - iam_client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn) + # Create DynamoDB policy + ddb_policy_resp = iam_client.create_policy( + PolicyName="LambdaDynamoDBPolicy", PolicyDocument=json.dumps(dynamodb_policy) + ) + ddb_policy_arn = ddb_policy_resp["Policy"]["Arn"] + iam_client.attach_role_policy(RoleName=role_name, PolicyArn=ddb_policy_arn) yield role["Role"]["Arn"] - iam_client.detach_role_policy(RoleName=role_name, PolicyArn=policy_arn) - iam_client.delete_policy(PolicyArn=policy_arn) + iam_client.detach_role_policy(RoleName=role_name, PolicyArn=log_policy_arn) + iam_client.delete_policy(PolicyArn=log_policy_arn) + iam_client.detach_role_policy(RoleName=role_name, PolicyArn=ddb_policy_arn) + iam_client.delete_policy(PolicyArn=ddb_policy_arn) iam_client.delete_role(RoleName=role_name) @@ -194,6 +223,71 @@ def wait_for_function_active(function_name, lambda_client): raise FunctionNotActiveError +@pytest.fixture(scope="session") +def configured_api_gateway(api_gateway_client, lambda_client, flask_function: str): + region = lambda_client.meta.region_name + + api = api_gateway_client.create_rest_api(name="API Gateway Lambda integration") + rest_api_id = api["id"] + + resources = api_gateway_client.get_resources(restApiId=rest_api_id) + root_id = next(item["id"] for item in resources["items"] if item["path"] == "/") + + patient_check_res = api_gateway_client.create_resource( + restApiId=rest_api_id, parentId=root_id, pathPart="patient-check" + ) + patient_check_id = patient_check_res["id"] + + id_res = api_gateway_client.create_resource(restApiId=rest_api_id, parentId=patient_check_id, pathPart="{id}") + resource_id = id_res["id"] + + api_gateway_client.put_method( + restApiId=rest_api_id, + resourceId=resource_id, + httpMethod="GET", + authorizationType="NONE", + requestParameters={"method.request.path.id": True}, + ) + + # Integration with actual region + lambda_uri = ( + f"arn:aws:apigateway:{region}:lambda:path/2015-03-31/functions/" + f"arn:aws:lambda:{region}:000000000000:function:{flask_function}/invocations" + ) + api_gateway_client.put_integration( + restApiId=rest_api_id, + resourceId=resource_id, + httpMethod="GET", + type="AWS_PROXY", + integrationHttpMethod="POST", + uri=lambda_uri, + passthroughBehavior="WHEN_NO_MATCH", + ) + + # Permission with matching region + lambda_client.add_permission( + FunctionName=flask_function, + StatementId="apigateway-access", + Action="lambda:InvokeFunction", + Principal="apigateway.amazonaws.com", + SourceArn=f"arn:aws:execute-api:{region}:000000000000:{rest_api_id}/*/GET/patient-check/*", + ) + + # Deploy the API + api_gateway_client.create_deployment(restApiId=rest_api_id, stageName="dev") + + return { + "rest_api_id": rest_api_id, + "resource_id": resource_id, + "invoke_url": f"http://{rest_api_id}.execute-api.localhost.localstack.cloud:4566/dev/patient-check/{{id}}", + } + + +@pytest.fixture +def api_gateway_endpoint(configured_api_gateway: dict) -> URL: + return URL(f"http://{configured_api_gateway['rest_api_id']}.execute-api.localhost.localstack.cloud:4566/dev") + + @pytest.fixture(scope="session") def person_table(dynamodb_resource: ServiceResource) -> Generator[Any]: table = dynamodb_resource.create_table( diff --git a/tests/integration/lambda/test_app_running_as_lambda.py b/tests/integration/lambda/test_app_running_as_lambda.py index a8acc5374..360410f4b 100644 --- a/tests/integration/lambda/test_app_running_as_lambda.py +++ b/tests/integration/lambda/test_app_running_as_lambda.py @@ -34,7 +34,12 @@ def test_install_and_call_lambda_flask( "routeKey": "GET /", "rawPath": "/", "rawQueryString": "", - "headers": {"accept": "application/json", "content-type": "application/json"}, + "headers": { + "accept": "application/json", + "content-type": "application/json", + "nhs-login-nhs-number": str(persisted_person), + }, + "pathParameters": {"id": str(persisted_person)}, "requestContext": { "http": { "sourceIp": "192.0.0.1", @@ -68,15 +73,19 @@ def test_install_and_call_lambda_flask( def test_install_and_call_flask_lambda_over_http( - flask_function_url: URL, persisted_person: NHSNumber, campaign_config: CampaignConfig, # noqa: ARG001 + api_gateway_endpoint: URL, ): - """Given lambda installed into localstack, run it via http""" + """Given api-gateway and lambda installed into localstack, run it via http""" # Given - # When - response = httpx.get(str(flask_function_url / "patient-check" / persisted_person)) + invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" + response = httpx.get( + invoke_url, + headers={"nhs-login-nhs-number": str(persisted_person)}, + timeout=10, + ) # Then assert_that( @@ -86,10 +95,10 @@ def test_install_and_call_flask_lambda_over_http( def test_install_and_call_flask_lambda_with_unknown_nhs_number( - flask_function_url: URL, flask_function: str, campaign_config: CampaignConfig, # noqa: ARG001 logs_client: BaseClient, + api_gateway_endpoint: URL, faker: Faker, ): """Given lambda installed into localstack, run it via http, with a nonexistent NHS number specified""" @@ -97,7 +106,12 @@ def test_install_and_call_flask_lambda_with_unknown_nhs_number( nhs_number = NHSNumber(faker.nhs_number()) # When - response = httpx.get(str(flask_function_url / "patient-check" / nhs_number)) + invoke_url = f"{api_gateway_endpoint}/patient-check/{nhs_number}" + response = httpx.get( + invoke_url, + headers={"nhs-login-nhs-number": str(nhs_number)}, + timeout=10, + ) # Then assert_that( @@ -136,3 +150,47 @@ def get_log_messages(flask_function: str, logs_client: BaseClient) -> list[str]: logGroupName=f"/aws/lambda/{flask_function}", logStreamName=log_stream_name, limit=100 ) return [e["message"] for e in log_events["events"]] + + +def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers( + lambda_client: BaseClient, # noqa:ARG001 + persisted_person: NHSNumber, + campaign_config: CampaignConfig, # noqa:ARG001 + api_gateway_endpoint: URL, +): + # Given + # When + invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" + response = httpx.get( + invoke_url, + headers={"nhs-login-nhs-number": str(persisted_person)}, + timeout=10, + ) + + # Then + assert_that( + response, + is_response().with_status_code(HTTPStatus.OK).and_body(is_json_that(has_key("processedSuggestions"))), + ) + + +def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_results_in_error_response( + lambda_client: BaseClient, # noqa:ARG001 + persisted_person: NHSNumber, + campaign_config: CampaignConfig, # noqa:ARG001 + api_gateway_endpoint: URL, +): + # Given + # When + invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" + response = httpx.get( + invoke_url, + headers={"nhs-login-nhs-number": f"123{persisted_person!s}"}, + timeout=10, + ) + + # Then + assert_that( + response, + is_response().with_status_code(HTTPStatus.FORBIDDEN).and_body("NHS number mismatch"), + )