Skip to content

Commit 1e0293d

Browse files
committed
Add as_of_date filer to LivelihoodZoneBaselineFilterSet and corresponding tests test_as_of_date_filter_returns_valid_baselines see HEA-914
1 parent 0d4321e commit 1e0293d

6 files changed

Lines changed: 165 additions & 8 deletions

File tree

apps/baseline/models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Models for managing HEA Baseline Surveys
33
"""
44

5+
import datetime
56
import numbers
67

78
from django.conf import settings
@@ -131,7 +132,22 @@ class ExtraMeta:
131132
identifier = ["code"]
132133

133134

134-
class LivelihoodZoneBaselineManager(common_models.IdentifierManager):
135+
class LivelihoodZoneBaselineQuerySet(models.QuerySet):
136+
"""
137+
QuerySet for LivelihoodZoneBaseline that provides temporal filtering methods.
138+
"""
139+
140+
def filter_current(self, as_of_date=None):
141+
if not as_of_date:
142+
as_of_date = datetime.date.today()
143+
return self.filter(models.Q(valid_to_date__isnull=True) | models.Q(valid_to_date__gte=as_of_date))
144+
145+
def current_all(self, as_of_date=None):
146+
# Return all the baselines that are valid as of the date specified.
147+
return self.filter_current(as_of_date).all()
148+
149+
150+
class LivelihoodZoneBaselineManager(common_models.IdentifierManager.from_queryset(LivelihoodZoneBaselineQuerySet)):
135151
def get_by_natural_key(self, code: str, reference_year_end_date: str):
136152
return self.get(livelihood_zone__code=code, reference_year_end_date=reference_year_end_date)
137153

apps/baseline/tests/factories.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ class Meta:
119119
bss_language = factory.Iterator(["en", "pt", "es", "ar", "fr"])
120120
reference_year_start_date = factory.LazyAttribute(lambda o: o.reference_year_end_date - relativedelta(years=1))
121121
reference_year_end_date = factory.Sequence(lambda n: datetime.date(1900, 1, 1) + datetime.timedelta(days=n + 10))
122-
valid_from_date = factory.Sequence(lambda n: datetime.date(1900, 1, 1) + datetime.timedelta(days=n))
123-
valid_to_date = factory.Sequence(lambda n: datetime.date(1900, 1, 1) + datetime.timedelta(days=n + 10))
122+
# Default to None so factory-created baselines are always valid regardless of as_of_date filter
123+
valid_from_date = None
124+
valid_to_date = None
124125
population_source = factory.Sequence(lambda n: f"population_source {n}")
125126
population_estimate = fuzzy.FuzzyInteger(500, 1000000)
126127
currency = factory.SubFactory(CurrencyFactory)

apps/baseline/tests/test_viewsets.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,63 @@ def test_filter_by_wealth_characteristic(self):
595595
self.assertEqual(response.status_code, 200)
596596
self.assertEqual(len(response.json()), 1)
597597

598+
def test_as_of_date_filter_returns_valid_baselines(self):
599+
# Test that the as_of_date filter returns valid baselines as of the specified date.
600+
today = datetime.date.today()
601+
expired_baseline = LivelihoodZoneBaselineFactory(
602+
valid_from_date=today - datetime.timedelta(days=365),
603+
valid_to_date=today - datetime.timedelta(days=30),
604+
)
605+
current_baseline = LivelihoodZoneBaselineFactory(
606+
valid_from_date=today - datetime.timedelta(days=30),
607+
valid_to_date=today + datetime.timedelta(days=365),
608+
)
609+
# Test default behavior (as of today) - should exclude expired_baseline
610+
response = self.client.get(self.url)
611+
self.assertEqual(response.status_code, 200)
612+
baseline_ids = [b["id"] for b in response.json()]
613+
self.assertIn(current_baseline.id, baseline_ids)
614+
self.assertNotIn(expired_baseline.id, baseline_ids)
615+
616+
# Test with a past date before the expired_baseline's valid_to_date
617+
past_date = today - datetime.timedelta(days=180)
618+
response = self.client.get(self.url, {"as_of_date": past_date.isoformat()})
619+
self.assertEqual(response.status_code, 200)
620+
baseline_ids = [b["id"] for b in response.json()]
621+
self.assertIn(expired_baseline.id, baseline_ids)
622+
self.assertIn(current_baseline.id, baseline_ids)
623+
624+
# Test with a future date - expired_baseline should still be excluded
625+
future_date = today + datetime.timedelta(days=60)
626+
response = self.client.get(self.url, {"as_of_date": future_date.isoformat()})
627+
self.assertEqual(response.status_code, 200)
628+
baseline_ids = [b["id"] for b in response.json()]
629+
self.assertNotIn(expired_baseline.id, baseline_ids)
630+
self.assertIn(current_baseline.id, baseline_ids)
631+
632+
# test as of date filter handles null dates
633+
baseline_no_from = LivelihoodZoneBaselineFactory(
634+
valid_from_date=None,
635+
valid_to_date=today + datetime.timedelta(days=365),
636+
)
637+
# Create a baseline with null valid_to_date (valid indefinitely)
638+
baseline_no_to = LivelihoodZoneBaselineFactory(
639+
valid_from_date=today - datetime.timedelta(days=365),
640+
valid_to_date=None,
641+
)
642+
# Create a baseline with both dates null (always valid)
643+
baseline_no_dates = LivelihoodZoneBaselineFactory(
644+
valid_from_date=None,
645+
valid_to_date=None,
646+
)
647+
# Test default behavior - all three should be returned
648+
response = self.client.get(self.url)
649+
self.assertEqual(response.status_code, 200)
650+
baseline_ids = [b["id"] for b in response.json()]
651+
self.assertIn(baseline_no_from.id, baseline_ids)
652+
self.assertIn(baseline_no_to.id, baseline_ids)
653+
self.assertIn(baseline_no_dates.id, baseline_ids)
654+
598655

599656
class LivelihoodZoneBaselineFacetedSearchViewTestCase(APITestCase):
600657
def setUp(self):
@@ -641,7 +698,8 @@ def test_search_with_product(self):
641698
# Apply the filters to the baseline
642699
baseline_url = reverse("livelihoodzonebaseline-list")
643700
response = self.client.get(
644-
baseline_url, {search_data["products"][0]["filter"]: search_data["products"][0]["value"]}
701+
baseline_url,
702+
{search_data["products"][0]["filter"]: search_data["products"][0]["value"]},
645703
)
646704
self.assertEqual(response.status_code, 200)
647705
self.assertEqual(len(json.loads(response.content)), 2)
@@ -650,7 +708,10 @@ def test_search_with_product(self):
650708
self.assertTrue(any(d["name"] == self.baseline3.name for d in data))
651709
self.assertFalse(any(d["name"] == self.baseline2.name for d in data))
652710

653-
response = self.client.get(baseline_url, {search_data["items"][0]["filter"]: search_data["items"][0]["value"]})
711+
response = self.client.get(
712+
baseline_url,
713+
{search_data["items"][0]["filter"]: search_data["items"][0]["value"]},
714+
)
654715
self.assertEqual(response.status_code, 200)
655716
self.assertEqual(len(json.loads(response.content)), 1)
656717
data = json.loads(response.content)

apps/baseline/viewsets.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from rest_framework.views import APIView
1414

1515
from common.fields import translation_fields
16-
from common.filters import MultiFieldFilter, UpperCaseFilter
16+
from common.filters import DefaultingDateFilter, MultiFieldFilter, UpperCaseFilter
1717
from common.viewsets import AggregatingViewSet, BaseModelViewSet
1818
from metadata.models import WealthGroupCategory
1919

@@ -205,6 +205,10 @@ class Meta:
205205
wealth_characteristic = CharFilter(
206206
method="filter_by_wealth_characteristic", label="Filter by Wealth Characteristic"
207207
)
208+
as_of_date = DefaultingDateFilter(
209+
label="As of Date",
210+
help_text="Filter baselines valid as of this date (YYYY-MM-DD format or special values like 'today').",
211+
)
208212

209213
def filter_by_product(self, queryset, name, value):
210214
"""
@@ -250,6 +254,14 @@ def filter_by_wealth_characteristic(self, queryset, name, value):
250254
class LivelihoodZoneBaselineViewSet(BaseModelViewSet):
251255
"""
252256
API endpoint that allows livelihood zone baselines to be viewed or edited.
257+
258+
By default, this endpoint returns only baselines that are currently valid (as of today's date).
259+
This behavior can be controlled using the `as_of_date` query parameter:
260+
261+
- No `as_of_date` parameter: Returns baselines valid as of today
262+
- `as_of_date=YYYY-MM-DD`: Returns baselines valid as of the specified date
263+
- `as_of_date=today`: Returns baselines valid as of today (explicit)
264+
- `as_of_date=<special_value>`: Supports special values like 'last_month', 'one_year_ago', etc.
253265
"""
254266

255267
queryset = LivelihoodZoneBaseline.objects.select_related(
@@ -266,6 +278,22 @@ class LivelihoodZoneBaselineViewSet(BaseModelViewSet):
266278
ordering_fields = ["livelihood_zone__code", "reference_year_end_date"]
267279
ordering = ["livelihood_zone__code", "reference_year_end_date"]
268280

281+
def get_queryset(self):
282+
"""
283+
Override get_queryset to apply default as_of_date filter for list actions.
284+
"""
285+
queryset = super().get_queryset()
286+
287+
if self.action != "list":
288+
return queryset
289+
290+
as_of_date_param = self.request.query_params.get("as_of_date", None)
291+
292+
if as_of_date_param is None:
293+
queryset = queryset.filter_current()
294+
295+
return queryset
296+
269297
def get_serializer_class(self):
270298
if self.request.accepted_renderer.format == "geojson":
271299
return LivelihoodZoneBaselineGeoSerializer # Use GeoFeatureModelSerializer for GeoJSON

apps/common/filters.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
from django.core.validators import EMPTY_VALUES
99
from django.db.models import F, Func, Q
1010
from django.forms import TextInput
11-
from django.forms.fields import ChoiceField, Field, MultipleChoiceField
11+
from django.forms.fields import ChoiceField, DateField, Field, MultipleChoiceField
1212
from django.forms.models import ModelMultipleChoiceField
1313
from django.utils.datastructures import MultiValueDict
1414
from django.utils.encoding import force_str
1515
from django_filters import ModelMultipleChoiceFilter, MultipleChoiceFilter
16-
from django_filters.filters import BooleanFilter, CharFilter, ChoiceFilter
16+
from django_filters.filters import BooleanFilter, CharFilter, ChoiceFilter, DateFilter
1717
from rest_framework.filters import OrderingFilter
1818

19+
from .utils import DEFAULT_DATES
20+
1921
logger = logging.getLogger(__name__)
2022

2123

@@ -238,3 +240,33 @@ def filter(self, qs, value):
238240
method = qs.exclude if exclude else qs.filter
239241

240242
return method(**{self.field_name: ""})
243+
244+
245+
class DefaultingDateField(DateField):
246+
"""
247+
A date field that accepts defaults like "last_month", "today"
248+
"""
249+
250+
def to_python(self, value):
251+
if value in DEFAULT_DATES:
252+
value = DEFAULT_DATES[value]()
253+
return super().to_python(value)
254+
255+
256+
class DefaultingDateFilter(DateFilter):
257+
"""
258+
A date filter that accepts defaults like "today"
259+
"""
260+
261+
field_class = DefaultingDateField
262+
263+
def filter(self, qs, value):
264+
if value:
265+
if self.lookup_expr in ["lte", "gte"]:
266+
# period_date filter for start and end date with either gte or lte expression will fall here.
267+
query = Q()
268+
query = Q(**{self.field_name + "__" + self.lookup_expr: value})
269+
return qs.filter(query)
270+
else:
271+
return qs.current_all(value)
272+
return qs

apps/common/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pathlib import Path
1010

1111
import pandas as pd
12+
from dateutil.relativedelta import relativedelta
1213
from django.apps import apps
1314
from django.db.migrations.operations.base import Operation
1415
from django.forms.models import modelform_factory
@@ -18,6 +19,24 @@
1819
logger = logging.getLogger(__name__)
1920

2021

22+
DEFAULT_DATES = {
23+
"ten_years_ago": lambda: datetime.date.today() + relativedelta(years=-10, day=1),
24+
"five_years_ago": lambda: datetime.date.today() + relativedelta(years=-5, day=1),
25+
"three_years_ago": lambda: datetime.date.today() + relativedelta(years=-3, day=1),
26+
"two_years_ago": lambda: datetime.date.today() + relativedelta(years=-2, day=1),
27+
"one_year_ago": lambda: datetime.date.today() + relativedelta(years=-1, day=1),
28+
"last_month_start": lambda: datetime.date.today() + relativedelta(months=-1, day=1),
29+
"last_month": lambda: datetime.date.today().replace(day=1) - datetime.timedelta(days=1),
30+
"second_last_month_start": lambda: datetime.date.today() + relativedelta(months=-2, day=1),
31+
"second_last_month": lambda: datetime.date.today() + relativedelta(months=-2, day=31),
32+
"third_last_month": lambda: datetime.date.today() + relativedelta(months=-3, day=31),
33+
"fourth_last_month": lambda: datetime.date.today() + relativedelta(months=-4, day=31),
34+
"sixth_last_month": lambda: datetime.date.today() + relativedelta(months=-6, day=31),
35+
"today": lambda: datetime.date.today(),
36+
"tomorrow": lambda: datetime.date.today() + relativedelta(days=1),
37+
}
38+
39+
2140
class UnicodeCsvReader(object):
2241
# TODO: Should check if it works without encoding and if this class needed on Python 3 at all.
2342
def __init__(self, f, encoding="utf-8", **kwargs):

0 commit comments

Comments
 (0)