Skip to content

Commit f0f6f05

Browse files
committed
eli-579 tightening up registry initialisation
1 parent 5f06436 commit f0f6f05

2 files changed

Lines changed: 54 additions & 6 deletions

File tree

src/eligibility_signposting_api/services/processors/derived_values/registry.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,40 @@ def register_default(cls, handler: DerivedValueHandler) -> None:
4444
Args:
4545
handler: The derived value handler to register
4646
"""
47-
cls._default_handlers[handler.function_name] = handler
47+
normalized_name = handler.function_name.upper()
48+
cls._default_handlers[normalized_name] = handler
49+
50+
if _RegistryHolder.instance is not None:
51+
_RegistryHolder.instance.register(handler)
52+
53+
def clear(self) -> None:
54+
"""Clear all handlers from this registry instance."""
55+
self._handlers.clear()
4856

4957
@classmethod
5058
def clear_defaults(cls) -> None:
5159
"""Clear all default handlers. Useful for testing."""
5260
cls._default_handlers.clear()
61+
if _RegistryHolder.instance is not None:
62+
_RegistryHolder.instance.clear()
5363

5464
@classmethod
5565
def get_default_handlers(cls) -> dict[str, DerivedValueHandler]:
5666
"""Get a copy of the default handlers. Useful for testing."""
5767
return cls._default_handlers.copy()
5868

69+
def set_handlers(self, handlers: dict[str, DerivedValueHandler]) -> None:
70+
"""Replace all handlers in this registry instance."""
71+
self._handlers = {name.upper(): handler for name, handler in handlers.items()}
72+
5973
@classmethod
6074
def set_default_handlers(cls, handlers: dict[str, DerivedValueHandler]) -> None:
6175
"""Set the default handlers. Useful for testing."""
62-
cls._default_handlers = handlers
76+
normalized = {name.upper(): handler for name, handler in handlers.items()}
77+
cls._default_handlers = normalized
78+
79+
if _RegistryHolder.instance is not None:
80+
_RegistryHolder.instance.set_handlers(handlers)
6381

6482
def register(self, handler: DerivedValueHandler) -> None:
6583
"""Register a derived value handler.
@@ -68,7 +86,8 @@ def register(self, handler: DerivedValueHandler) -> None:
6886
handler: The handler to register. Its function_name attribute
6987
will be used as the lookup key.
7088
"""
71-
self._handlers[handler.function_name] = handler
89+
normalized_name = handler.function_name.upper()
90+
self._handlers[normalized_name] = handler
7291

7392
def get_handler(self, function_name: str) -> DerivedValueHandler | None:
7493
"""Get a handler by function name.
@@ -151,8 +170,10 @@ def calculate(
151170
return handler.calculate(context)
152171

153172

154-
# Create a singleton instance for convenience
155-
_registry = DerivedValueRegistry()
173+
class _RegistryHolder:
174+
"""Holder for the singleton registry to avoid global statement."""
175+
176+
instance: DerivedValueRegistry | None = None
156177

157178

158179
def get_registry() -> DerivedValueRegistry:
@@ -161,4 +182,6 @@ def get_registry() -> DerivedValueRegistry:
161182
Returns:
162183
The singleton DerivedValueRegistry instance
163184
"""
164-
return _registry
185+
if _RegistryHolder.instance is None:
186+
_RegistryHolder.instance = DerivedValueRegistry()
187+
return _RegistryHolder.instance

tests/unit/services/processors/test_derived_values.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from eligibility_signposting_api.services.processors.derived_values import (
66
AddDaysHandler,
77
DerivedValueContext,
8+
DerivedValueHandler,
89
DerivedValueRegistry,
10+
get_registry,
911
)
1012

1113

@@ -360,6 +362,29 @@ def test_default_handlers_are_registered(self):
360362
# The default ADD_DAYS handler should be registered via __init__.py
361363
assert_that(registry.has_handler("ADD_DAYS"), is_(True))
362364

365+
def test_global_registry_has_default_handlers(self):
366+
"""Test that the exported registry singleton sees default handlers."""
367+
registry = get_registry()
368+
369+
assert_that(registry.has_handler("ADD_DAYS"), is_(True))
370+
371+
def test_register_normalizes_function_name(self):
372+
"""Test that registering handlers works regardless of name casing."""
373+
374+
class LowercaseHandler(DerivedValueHandler):
375+
function_name = "custom_func"
376+
377+
def calculate(self, context: DerivedValueContext) -> str: # noqa: ARG002
378+
return ""
379+
380+
def get_source_attribute(self, target_attribute: str, function_args: str | None = None) -> str: # noqa: ARG002
381+
return target_attribute
382+
383+
registry = DerivedValueRegistry()
384+
registry.register(LowercaseHandler())
385+
386+
assert_that(registry.has_handler("CUSTOM_FUNC"), is_(True))
387+
363388
def test_clear_defaults_removes_default_handlers(self):
364389
"""Test that clear_defaults removes all default handlers."""
365390
# Save current defaults using public method

0 commit comments

Comments
 (0)