Skip to content

Commit 750a2c5

Browse files
committed
updated iteration time
1 parent d924d1c commit 750a2c5

2 files changed

Lines changed: 60 additions & 27 deletions

File tree

src/eligibility_signposting_api/model/campaign_config.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
import typing
66
from collections import Counter
7-
from datetime import UTC, date, datetime
7+
from datetime import UTC, date, datetime, time
88
from enum import StrEnum
99
from functools import cached_property
1010
from operator import attrgetter
@@ -33,6 +33,7 @@
3333
IterationVersion = NewType("IterationVersion", int)
3434
IterationID = NewType("IterationID", str)
3535
IterationDate = NewType("IterationDate", date)
36+
IterationTime = NewType("IterationTime", time)
3637
RuleName = NewType("RuleName", str)
3738
RuleDescription = NewType("RuleDescription", str)
3839
RulePriority = NewType("RulePriority", int)
@@ -119,19 +120,19 @@ class IterationCohort(BaseModel):
119120

120121
@cached_property
121122
def is_virtual_cohort(self) -> bool:
122-
return self.virtual == Virtual.YES
123+
return self.virtual == Virtual.YES
123124

124125
@field_validator("virtual", mode="before")
125126
@classmethod
126127
def normalize_virtual(cls, value: str) -> Virtual:
127128
if value is None:
128-
return Virtual.NO
129+
return Virtual.NO
129130
if isinstance(value, str):
130131
value = value.strip().upper()
131132
if value == "Y":
132-
return Virtual.YES
133+
return Virtual.YES
133134
if value == "N":
134-
return Virtual.NO
135+
return Virtual.NO
135136
msg = f"Invalid value for Virtual: {value!r}"
136137
raise ValueError(msg)
137138

@@ -160,8 +161,8 @@ class IterationRule(BaseModel):
160161
@field_validator("rule_stop", mode="before")
161162
def parse_yn_to_bool(cls, v: str | bool) -> bool: # noqa: N805, FBT001
162163
if isinstance(v, str):
163-
return v.upper() == "Y"
164-
return v
164+
return v.upper() == "Y"
165+
return v
165166

166167
_parent: Iteration | None = PrivateAttr(default=None)
167168

@@ -183,7 +184,7 @@ def rule_code(self) -> str:
183184
for rule_entry in self._parent.rules_mapper.values():
184185
if rule_entry and self.name in rule_entry.rule_names:
185186
rule_code = rule_entry.rule_code
186-
return rule_code or self.code or self.name
187+
return rule_code or self.code or self.name
187188

188189
@property
189190
def rule_text(self) -> str:
@@ -200,7 +201,7 @@ def rule_text(self) -> str:
200201
for rule_entry in self._parent.rules_mapper.values():
201202
if rule_entry and self.name in rule_entry.rule_names:
202203
rule_text = rule_entry.rule_text
203-
return rule_text or self.description
204+
return rule_text or self.description
204205

205206
@cached_property
206207
def parsed_cohort_labels(self) -> list[str]:
@@ -211,11 +212,11 @@ def parsed_cohort_labels(self) -> list[str]:
211212
A list of cohort labels, split by comma. If no label is set, returns an empty list.
212213
"""
213214
if not self.cohort_label:
214-
return []
215-
return [label.strip() for label in self.cohort_label.split(",") if label.strip()]
215+
return []
216+
return [label.strip() for label in self.cohort_label.split(",") if label.strip()]
216217

217218
def __str__(self) -> str:
218-
return json.dumps(self.model_dump(by_alias=True), indent=2)
219+
return json.dumps(self.model_dump(by_alias=True), indent=2)
219220

220221

221222
class AvailableAction(BaseModel):
@@ -230,7 +231,7 @@ class AvailableAction(BaseModel):
230231

231232
class ActionsMapper(RootModel[dict[str, AvailableAction]]):
232233
def get(self, key: str, default: AvailableAction | None = None) -> AvailableAction | None:
233-
return self.root.get(key, default)
234+
return self.root.get(key, default)
234235

235236

236237
class StatusText(BaseModel):
@@ -251,17 +252,18 @@ class RuleEntry(BaseModel):
251252

252253
class RulesMapper(RootModel[dict[str, RuleEntry]]):
253254
def get(self, key: str, default: RuleEntry | None = None) -> RuleEntry | None:
254-
return self.root.get(key, default)
255+
return self.root.get(key, default)
255256

256257
def values(self) -> list[RuleEntry]:
257-
return list(self.root.values())
258+
return list(self.root.values())
258259

259260

260261
class Iteration(BaseModel):
261262
id: IterationID = Field(..., alias="ID")
262263
version: IterationVersion = Field(..., alias="Version")
263264
name: IterationName = Field(..., alias="Name")
264265
iteration_date: IterationDate = Field(..., alias="IterationDate")
266+
iteration_time: IterationTime = Field(..., alias="IterationTime")
265267
iteration_number: int | None = Field(None, alias="IterationNumber")
266268
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
267269
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
@@ -287,7 +289,7 @@ def __init__(self, **data: dict[str, typing.Any]) -> None:
287289
@classmethod
288290
def parse_dates(cls, v: str | date) -> date:
289291
if isinstance(v, date):
290-
return v
292+
return v
291293

292294
v_str = str(v)
293295

@@ -296,7 +298,7 @@ def parse_dates(cls, v: str | date) -> date:
296298
raise ValueError(msg)
297299

298300
try:
299-
return datetime.strptime(v_str, "%Y%m%d").date() # noqa: DTZ007
301+
return datetime.strptime(v_str, "%Y%m%d").date() # noqa: DTZ007
300302
except ValueError as err:
301303
msg = f"Invalid date value: {v_str}. Must be a valid calendar date in YYYYMMDD format."
302304
raise ValueError(msg) from err
@@ -306,10 +308,22 @@ def parse_dates(cls, v: str | date) -> date:
306308
def serialize_dates(v: date, _info: SerializationInfo) -> str:
307309
return v.strftime("%Y%m%d")
308310

311+
@property
312+
def get_iteration_datetime(self) -> datetime:
313+
iteration_time = (
314+
self.iteration_time
315+
or getattr(getattr(self, "parent", None), "default_iteration_time", None)
316+
)
317+
318+
if iteration_time is None:
319+
raise ValueError("No iteration_time available on object or parent.default_iteration_time.")
320+
321+
return datetime.combine(self.iteration_date, iteration_time)
322+
323+
309324
def __str__(self) -> str:
310325
return json.dumps(self.model_dump(by_alias=True), indent=2)
311326

312-
313327
class CampaignConfig(BaseModel):
314328
id: CampaignID = Field(..., alias="ID")
315329
version: CampaignVersion = Field(..., alias="Version")
@@ -321,7 +335,7 @@ class CampaignConfig(BaseModel):
321335
reviewer: list[str] | None = Field(None, alias="Reviewer")
322336
iteration_frequency: Literal["X", "D", "W", "M", "Q", "A"] = Field(..., alias="IterationFrequency")
323337
iteration_type: Literal["A", "M", "S", "O"] = Field(..., alias="IterationType")
324-
iteration_time: str | None = Field(None, alias="IterationTime")
338+
default_iteration_time: IterationTime = Field(default=IterationTime(time(0, 0, 0)), alias="IterationTime")
325339
default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting")
326340
start_date: StartDate = Field(..., alias="StartDate")
327341
end_date: EndDate = Field(..., alias="EndDate")
@@ -335,7 +349,7 @@ class CampaignConfig(BaseModel):
335349
@classmethod
336350
def parse_dates(cls, v: str | date) -> date:
337351
if isinstance(v, date):
338-
return v
352+
return v
339353

340354
v_str = str(v)
341355

@@ -344,22 +358,22 @@ def parse_dates(cls, v: str | date) -> date:
344358
raise ValueError(msg)
345359

346360
try:
347-
return datetime.strptime(v_str, "%Y%m%d").date() # noqa: DTZ007
361+
return datetime.strptime(v_str, "%Y%m%d").date() # noqa: DTZ007
348362
except ValueError as err:
349363
msg = f"Invalid date value: {v_str}. Must be a valid calendar date in YYYYMMDD format."
350364
raise ValueError(msg) from err
351365

352366
@field_serializer("start_date", "end_date", when_used="always")
353367
@staticmethod
354368
def serialize_dates(v: date, _info: SerializationInfo) -> str:
355-
return v.strftime("%Y%m%d")
369+
return v.strftime("%Y%m%d")
356370

357371
@model_validator(mode="after")
358372
def check_start_and_end_dates_sensible(self) -> typing.Self:
359373
if self.start_date > self.end_date:
360374
message = f"start date {self.start_date} after end date {self.end_date}"
361375
raise ValueError(message)
362-
return self
376+
return self
363377

364378
@model_validator(mode="after")
365379
def check_no_overlapping_iterations(self) -> typing.Self:
@@ -368,21 +382,21 @@ def check_no_overlapping_iterations(self) -> typing.Self:
368382
iteration_date, count = multiple_found
369383
message = f"{count} iterations with iteration date {iteration_date} in campaign {self.id}"
370384
raise ValueError(message)
371-
return self
385+
return self
372386

373387
@cached_property
374388
def campaign_live(self) -> bool:
375389
today = datetime.now(tz=UTC).date()
376-
return self.start_date <= today <= self.end_date
390+
return self.start_date <= today <= self.end_date
377391

378392
@cached_property
379393
def current_iteration(self) -> Iteration:
380394
today = datetime.now(tz=UTC).date()
381395
iterations_by_date_descending = sorted(self.iterations, key=attrgetter("iteration_date"), reverse=True)
382-
return next(i for i in iterations_by_date_descending if i.iteration_date <= today)
396+
return next(i for i in iterations_by_date_descending if i.iteration_date <= today)
383397

384398
def __str__(self) -> str:
385-
return json.dumps(self.model_dump(by_alias=True), indent=2)
399+
return json.dumps(self.model_dump(by_alias=True), indent=2)
386400

387401

388402
class Rules(BaseModel):

tests/unit/validation/test_campaign_config_validator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,22 @@ def test_approval_minimum_greater_than_approval_maximum_is_invalid(
318318
data["ApprovalMinimum"] = approval_min
319319
data["ApprovalMaximum"] = approval_max
320320
CampaignConfigValidation(**data)
321+
322+
def test_iteration_time_overrides_default_iteration_time(
323+
self,
324+
valid_iteration_config_with_only_mandatory_fields,
325+
):
326+
# Arrange
327+
data = valid_iteration_config_with_only_mandatory_fields.copy()
328+
data["default_iteration_time"] = "09:00:00"
329+
data["iteration_time"] = "14:30:00"
330+
config = CampaignConfigValidation(**data)
331+
332+
# Act
333+
result = config.get_iteration_datetime
334+
335+
# Assert
336+
assert result.time() == IterationTime(14, 30), (
337+
"Expected iteration_time to take precedence over default_iteration_time"
338+
)
339+

0 commit comments

Comments
 (0)