Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/eligibility_signposting_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from eligibility_signposting_api import audit, repos, services
from eligibility_signposting_api.common.error_handler import handle_exception
from eligibility_signposting_api.common.request_validator import validate_request_params
from eligibility_signposting_api.config.config import config, init_logging
from eligibility_signposting_api.config.config import config
from eligibility_signposting_api.logging.logs_manager import add_request_id_to_logs, init_logging
from eligibility_signposting_api.views import eligibility_blueprint

init_logging()
Expand All @@ -23,6 +24,7 @@ def main() -> None: # pragma: no cover
app.run(debug=config()["log_level"] == logging.DEBUG)


@add_request_id_to_logs()
@validate_request_params()
def lambda_handler(event: LambdaEvent, context: LambdaContext) -> dict[str, Any]: # pragma: no cover
"""Run the Flask app as an AWS Lambda."""
Expand Down
2 changes: 1 addition & 1 deletion src/eligibility_signposting_api/audit/audit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class AuditCondition(CamelCaseBaseModel):

class ResponseAuditData(CamelCaseBaseModel):
response_id: UUID | None = None
last_updated: str | None = None
last_updated: datetime | None = None
condition: list[AuditCondition] = Field(default_factory=list)


Expand Down
15 changes: 0 additions & 15 deletions src/eligibility_signposting_api/config/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import logging
import os
from collections.abc import Sequence
from functools import cache
from typing import Any, NewType

from pythonjsonlogger.json import JsonFormatter
from yarl import URL

from eligibility_signposting_api.repos.campaign_repo import BucketName
Expand Down Expand Up @@ -57,16 +55,3 @@ def config() -> dict[str, Any]:
"kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3,
"log_level": log_level,
}


def init_logging(quieten: Sequence[str] = ("asyncio", "botocore", "boto3", "mangum", "urllib3")) -> None:
log_format = "%(asctime)s %(levelname)-8s %(name)s %(module)s.py:%(funcName)s():%(lineno)d %(message)s"
formatter = JsonFormatter(log_format)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logging.root.handlers = [] # Clear any existing handlers
logging.root.setLevel(LOG_LEVEL) # Set log level
logging.root.addHandler(handler) # Add handler

for q in quieten:
logging.getLogger(q).setLevel(logging.WARNING)
Empty file.
48 changes: 48 additions & 0 deletions src/eligibility_signposting_api/logging/logs_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from collections.abc import Callable, Sequence
from contextvars import ContextVar
from functools import wraps
from typing import Any

from mangum.types import LambdaContext, LambdaEvent
from pythonjsonlogger.json import JsonFormatter

from eligibility_signposting_api.config.config import LOG_LEVEL

request_id_context_var: ContextVar[str | None] = ContextVar("request_id", default=None)

LOG_FORMAT = "%(asctime)s %(levelname)-8s %(name)s %(module)s.py:%(funcName)s():%(lineno)d %(message)s"


def add_request_id_to_logs() -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(event: LambdaEvent, context: LambdaContext) -> dict[str, Any] | None:
aws_request_id = request_id_context_var.set(context.aws_request_id)
try:
return func(event, context)
finally:
request_id_context_var.reset(aws_request_id)

return wrapper

return decorator


class EnrichedJsonFormatter(JsonFormatter):
def add_fields(self, log_record: dict[str, Any], record: logging.LogRecord, message_dict: dict[str, Any]) -> None:
log_record["request_id"] = request_id_context_var.get() or "-"
super().add_fields(log_record, record, message_dict)


def init_logging(quieten: Sequence[str] = ("asyncio", "botocore", "boto3", "mangum", "urllib3")) -> None:
formatter = EnrichedJsonFormatter(LOG_FORMAT)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

logging.root.handlers = [] # Remove default handlers
logging.root.setLevel(LOG_LEVEL)
logging.root.addHandler(handler)

for q in quieten:
logging.getLogger(q).setLevel(logging.WARNING)
Empty file.
120 changes: 120 additions & 0 deletions tests/unit/logging/test_logs_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import io
import json
import logging
import threading
from http import HTTPStatus
from unittest.mock import MagicMock, Mock

import pytest
from mangum.types import LambdaContext

from eligibility_signposting_api.logging.logs_manager import (
LOG_FORMAT,
EnrichedJsonFormatter,
add_request_id_to_logs,
request_id_context_var,
)


def test_decorator_sets_request_id_in_context():
test_request_id = "test-id-12345"
mock_context = MagicMock()
mock_context.aws_request_id = test_request_id

@add_request_id_to_logs()
def decorated_handler(event, context): # noqa : ARG001
return request_id_context_var.get()

result = decorated_handler({}, mock_context)

assert result == test_request_id


def test_decorator_preserves_function_return_value():
expected_result = {"statusCode": 200, "body": "Success"}
mock_context = MagicMock()
mock_context.aws_request_id = "any-id"

@add_request_id_to_logs()
def decorated_handler(event, context): # noqa : ARG001
return expected_result

result = decorated_handler({}, mock_context)

assert result == expected_result


def test_request_id_context_is_properly_isolated():
results = {}

@add_request_id_to_logs()
def decorated_handler(event, context): # noqa : ARG001
rid = request_id_context_var.get()
results[threading.current_thread().name] = rid
return rid

def thread_func(name, rid): # noqa : ARG001
mock_context = MagicMock(aws_request_id=rid)
decorated_handler({}, mock_context)

threads = [
threading.Thread(target=thread_func, name="Thread-A", args=("Thread-A", "id-A")),
threading.Thread(target=thread_func, name="Thread-B", args=("Thread-B", "id-B")),
threading.Thread(target=thread_func, name="Thread-C", args=("Thread-C", "id-C")),
]

for t in threads:
t.start()
for t in threads:
t.join()

assert results["Thread-A"] == "id-A"
assert request_id_context_var.get() is None

assert results["Thread-B"] == "id-B"
assert request_id_context_var.get() is None

assert results["Thread-C"] == "id-C"
assert request_id_context_var.get() is None


@pytest.fixture
def lambda_context():
context = Mock(spec=LambdaContext)
context.aws_request_id = "test-request-id"
return context


def test_enriched_json_formatter_adds_all_fields(lambda_context):
@add_request_id_to_logs()
def test_handler(event, context): # noqa : ARG001
logger = logging.getLogger("test_logger")
logger.info("Test log inside handler")
return HTTPStatus.OK

log_stream = io.StringIO()
handler = logging.StreamHandler(log_stream)
handler.setFormatter(EnrichedJsonFormatter(LOG_FORMAT))

test_logger = logging.getLogger("test_logger")
test_logger.handlers = []
test_logger.addHandler(handler)
test_logger.setLevel(logging.INFO)

result = test_handler({}, lambda_context)
log_output = log_stream.getvalue()

test_logger.removeHandler(handler)

assert result == HTTPStatus.OK
logged_json = json.loads(log_output)

assert logged_json["request_id"] == lambda_context.aws_request_id
assert "asctime" in logged_json
assert logged_json["levelname"] == "INFO"
assert logged_json["name"] == "test_logger"
assert logged_json["module"] == "test_logs_manager"
assert logged_json["funcName"] == "test_handler"
assert "lineno" in logged_json
assert logged_json["message"] == "Test log inside handler"
assert request_id_context_var.get() is None