diff --git a/src/eligibility_signposting_api/app.py b/src/eligibility_signposting_api/app.py index 6246a807c..70375fafe 100644 --- a/src/eligibility_signposting_api/app.py +++ b/src/eligibility_signposting_api/app.py @@ -10,7 +10,8 @@ from eligibility_signposting_api import audit, repos, services from eligibility_signposting_api.common.error_handler import handle_exception from eligibility_signposting_api.common.request_validator import validate_request_params -from eligibility_signposting_api.config.config import config, init_logging +from eligibility_signposting_api.config.config import config +from eligibility_signposting_api.logging.logs_manager import add_request_id_to_logs, init_logging from eligibility_signposting_api.views import eligibility_blueprint init_logging() @@ -23,6 +24,7 @@ def main() -> None: # pragma: no cover app.run(debug=config()["log_level"] == logging.DEBUG) +@add_request_id_to_logs() @validate_request_params() def lambda_handler(event: LambdaEvent, context: LambdaContext) -> dict[str, Any]: # pragma: no cover """Run the Flask app as an AWS Lambda.""" diff --git a/src/eligibility_signposting_api/audit/audit_models.py b/src/eligibility_signposting_api/audit/audit_models.py index 17467130f..d80409a8c 100644 --- a/src/eligibility_signposting_api/audit/audit_models.py +++ b/src/eligibility_signposting_api/audit/audit_models.py @@ -86,7 +86,7 @@ class AuditCondition(CamelCaseBaseModel): class ResponseAuditData(CamelCaseBaseModel): response_id: UUID | None = None - last_updated: str | None = None + last_updated: datetime | None = None condition: list[AuditCondition] = Field(default_factory=list) diff --git a/src/eligibility_signposting_api/config/config.py b/src/eligibility_signposting_api/config/config.py index 722e90133..49faeff6b 100644 --- a/src/eligibility_signposting_api/config/config.py +++ b/src/eligibility_signposting_api/config/config.py @@ -1,10 +1,8 @@ import logging import os -from collections.abc import Sequence from functools import cache from typing import Any, NewType -from pythonjsonlogger.json import JsonFormatter from yarl import URL from eligibility_signposting_api.repos.campaign_repo import BucketName @@ -57,16 +55,3 @@ def config() -> dict[str, Any]: "kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3, "log_level": log_level, } - - -def init_logging(quieten: Sequence[str] = ("asyncio", "botocore", "boto3", "mangum", "urllib3")) -> None: - log_format = "%(asctime)s %(levelname)-8s %(name)s %(module)s.py:%(funcName)s():%(lineno)d %(message)s" - formatter = JsonFormatter(log_format) - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logging.root.handlers = [] # Clear any existing handlers - logging.root.setLevel(LOG_LEVEL) # Set log level - logging.root.addHandler(handler) # Add handler - - for q in quieten: - logging.getLogger(q).setLevel(logging.WARNING) diff --git a/src/eligibility_signposting_api/logging/__init__.py b/src/eligibility_signposting_api/logging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/eligibility_signposting_api/logging/logs_manager.py b/src/eligibility_signposting_api/logging/logs_manager.py new file mode 100644 index 000000000..093c06dda --- /dev/null +++ b/src/eligibility_signposting_api/logging/logs_manager.py @@ -0,0 +1,48 @@ +import logging +from collections.abc import Callable, Sequence +from contextvars import ContextVar +from functools import wraps +from typing import Any + +from mangum.types import LambdaContext, LambdaEvent +from pythonjsonlogger.json import JsonFormatter + +from eligibility_signposting_api.config.config import LOG_LEVEL + +request_id_context_var: ContextVar[str | None] = ContextVar("request_id", default=None) + +LOG_FORMAT = "%(asctime)s %(levelname)-8s %(name)s %(module)s.py:%(funcName)s():%(lineno)d %(message)s" + + +def add_request_id_to_logs() -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(event: LambdaEvent, context: LambdaContext) -> dict[str, Any] | None: + aws_request_id = request_id_context_var.set(context.aws_request_id) + try: + return func(event, context) + finally: + request_id_context_var.reset(aws_request_id) + + return wrapper + + return decorator + + +class EnrichedJsonFormatter(JsonFormatter): + def add_fields(self, log_record: dict[str, Any], record: logging.LogRecord, message_dict: dict[str, Any]) -> None: + log_record["request_id"] = request_id_context_var.get() or "-" + super().add_fields(log_record, record, message_dict) + + +def init_logging(quieten: Sequence[str] = ("asyncio", "botocore", "boto3", "mangum", "urllib3")) -> None: + formatter = EnrichedJsonFormatter(LOG_FORMAT) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logging.root.handlers = [] # Remove default handlers + logging.root.setLevel(LOG_LEVEL) + logging.root.addHandler(handler) + + for q in quieten: + logging.getLogger(q).setLevel(logging.WARNING) diff --git a/tests/unit/logging/__init__.py b/tests/unit/logging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/logging/test_logs_manager.py b/tests/unit/logging/test_logs_manager.py new file mode 100644 index 000000000..eb9c38a6c --- /dev/null +++ b/tests/unit/logging/test_logs_manager.py @@ -0,0 +1,120 @@ +import io +import json +import logging +import threading +from http import HTTPStatus +from unittest.mock import MagicMock, Mock + +import pytest +from mangum.types import LambdaContext + +from eligibility_signposting_api.logging.logs_manager import ( + LOG_FORMAT, + EnrichedJsonFormatter, + add_request_id_to_logs, + request_id_context_var, +) + + +def test_decorator_sets_request_id_in_context(): + test_request_id = "test-id-12345" + mock_context = MagicMock() + mock_context.aws_request_id = test_request_id + + @add_request_id_to_logs() + def decorated_handler(event, context): # noqa : ARG001 + return request_id_context_var.get() + + result = decorated_handler({}, mock_context) + + assert result == test_request_id + + +def test_decorator_preserves_function_return_value(): + expected_result = {"statusCode": 200, "body": "Success"} + mock_context = MagicMock() + mock_context.aws_request_id = "any-id" + + @add_request_id_to_logs() + def decorated_handler(event, context): # noqa : ARG001 + return expected_result + + result = decorated_handler({}, mock_context) + + assert result == expected_result + + +def test_request_id_context_is_properly_isolated(): + results = {} + + @add_request_id_to_logs() + def decorated_handler(event, context): # noqa : ARG001 + rid = request_id_context_var.get() + results[threading.current_thread().name] = rid + return rid + + def thread_func(name, rid): # noqa : ARG001 + mock_context = MagicMock(aws_request_id=rid) + decorated_handler({}, mock_context) + + threads = [ + threading.Thread(target=thread_func, name="Thread-A", args=("Thread-A", "id-A")), + threading.Thread(target=thread_func, name="Thread-B", args=("Thread-B", "id-B")), + threading.Thread(target=thread_func, name="Thread-C", args=("Thread-C", "id-C")), + ] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert results["Thread-A"] == "id-A" + assert request_id_context_var.get() is None + + assert results["Thread-B"] == "id-B" + assert request_id_context_var.get() is None + + assert results["Thread-C"] == "id-C" + assert request_id_context_var.get() is None + + +@pytest.fixture +def lambda_context(): + context = Mock(spec=LambdaContext) + context.aws_request_id = "test-request-id" + return context + + +def test_enriched_json_formatter_adds_all_fields(lambda_context): + @add_request_id_to_logs() + def test_handler(event, context): # noqa : ARG001 + logger = logging.getLogger("test_logger") + logger.info("Test log inside handler") + return HTTPStatus.OK + + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + handler.setFormatter(EnrichedJsonFormatter(LOG_FORMAT)) + + test_logger = logging.getLogger("test_logger") + test_logger.handlers = [] + test_logger.addHandler(handler) + test_logger.setLevel(logging.INFO) + + result = test_handler({}, lambda_context) + log_output = log_stream.getvalue() + + test_logger.removeHandler(handler) + + assert result == HTTPStatus.OK + logged_json = json.loads(log_output) + + assert logged_json["request_id"] == lambda_context.aws_request_id + assert "asctime" in logged_json + assert logged_json["levelname"] == "INFO" + assert logged_json["name"] == "test_logger" + assert logged_json["module"] == "test_logs_manager" + assert logged_json["funcName"] == "test_handler" + assert "lineno" in logged_json + assert logged_json["message"] == "Test log inside handler" + assert request_id_context_var.get() is None