Skip to content

Commit 0382ddd

Browse files
authored
Merge pull request #254 from American-Institutes-for-Research/HEA-786/fix_cascading_api
Update the country and product endpoint filter for the cascading of L…
2 parents 17943a3 + f9ee849 commit 0382ddd

2 files changed

Lines changed: 152 additions & 1 deletion

File tree

apps/common/tests/test_viewsets.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,30 @@ def setUpTestData(cls):
2828
# import baseline factory to avoid circular depdnecies
2929
module = importlib.import_module("baseline.tests.factories")
3030
WealthGroupFactory = getattr(module, "WealthGroupFactory")
31+
LivelihoodZoneBaselineFactory = getattr(module, "LivelihoodZoneBaselineFactory")
32+
LivelihoodStrategyFactory = getattr(module, "LivelihoodStrategyFactory")
3133

3234
cls.wealth_group1 = WealthGroupFactory(livelihood_zone_baseline__livelihood_zone__country=cls.country1)
3335

36+
# Set up livelihood category cascade filter test data
37+
metadata_module = importlib.import_module("metadata.tests.factories")
38+
LivelihoodCategoryFactory = getattr(metadata_module, "LivelihoodCategoryFactory")
39+
cls.category_a = LivelihoodCategoryFactory()
40+
cls.category_b = LivelihoodCategoryFactory()
41+
LivelihoodZoneBaselineFactory(
42+
livelihood_zone__country=cls.country2,
43+
main_livelihood_category=cls.category_a,
44+
)
45+
46+
# Set up strategy_type cascade filter test data
47+
cls.strategy_type = "CropProduction"
48+
baseline_with_strategy = LivelihoodZoneBaselineFactory(livelihood_zone__country=cls.country3)
49+
LivelihoodStrategyFactory(
50+
livelihood_zone_baseline=baseline_with_strategy,
51+
strategy_type=cls.strategy_type,
52+
season__country=cls.country3,
53+
)
54+
3455
def setUp(self):
3556
self.url = reverse("country-list")
3657

@@ -105,6 +126,37 @@ def test_filter_by_has_wealthgroups(self):
105126
result = json.loads(response.content.decode("utf-8"))
106127
self.assertNotIn(self.country1.iso3166a2, result)
107128

129+
def test_filter_by_livelihood_category(self):
130+
# country2 has a baseline with category_a
131+
response = self.client.get(self.url, {"livelihood_category": self.category_a.code})
132+
self.assertEqual(response.status_code, 200)
133+
result = json.loads(response.content.decode("utf-8"))
134+
country_codes = [r["iso3166a2"] for r in result]
135+
self.assertIn(self.country2.iso3166a2, country_codes)
136+
self.assertNotIn(self.country4.iso3166a2, country_codes)
137+
138+
# category_b has no baselines
139+
response = self.client.get(self.url, {"livelihood_category": self.category_b.code})
140+
self.assertEqual(response.status_code, 200)
141+
result = json.loads(response.content.decode("utf-8"))
142+
self.assertEqual(len(result), 0)
143+
144+
def test_filter_by_strategy_type(self):
145+
# country3 has a baseline with CropProduction strategy
146+
response = self.client.get(self.url, {"strategy_type": self.strategy_type})
147+
self.assertEqual(response.status_code, 200)
148+
result = json.loads(response.content.decode("utf-8"))
149+
country_codes = [r["iso3166a2"] for r in result]
150+
self.assertIn(self.country3.iso3166a2, country_codes)
151+
self.assertNotIn(self.country4.iso3166a2, country_codes)
152+
153+
# country4 has no strategies
154+
response = self.client.get(self.url, {"strategy_type": "FoodPurchase"})
155+
self.assertEqual(response.status_code, 200)
156+
result = json.loads(response.content.decode("utf-8"))
157+
country_codes = [r["iso3166a2"] for r in result]
158+
self.assertNotIn(self.country4.iso3166a2, country_codes)
159+
108160

109161
class CurrencyViewSetTestCase(APITestCase):
110162
@classmethod
@@ -202,7 +254,28 @@ def setUpTestData(cls):
202254
livelihood_zone_baseline = LivelihoodZoneBaselineFactory(livelihood_zone__country=cls.country_a)
203255
WealthGroupFactory(livelihood_zone_baseline=livelihood_zone_baseline)
204256
cls.livelihood_strategy1 = LivelihoodStrategyFactory(
205-
livelihood_zone_baseline=livelihood_zone_baseline, product=cls.product1
257+
livelihood_zone_baseline=livelihood_zone_baseline, product=cls.product1, season__country=cls.country_a
258+
)
259+
260+
# Set up livelihood category cascade filter test data
261+
metadata_module = importlib.import_module("metadata.tests.factories")
262+
LivelihoodCategoryFactory = getattr(metadata_module, "LivelihoodCategoryFactory")
263+
cls.category_a = LivelihoodCategoryFactory()
264+
cls.category_b = LivelihoodCategoryFactory()
265+
baseline_category_a = LivelihoodZoneBaselineFactory(
266+
livelihood_zone__country=cls.country_a, main_livelihood_category=cls.category_a
267+
)
268+
LivelihoodStrategyFactory(
269+
livelihood_zone_baseline=baseline_category_a, product=cls.product1, season__country=cls.country_a
270+
)
271+
272+
# Set up strategy_type cascade filter test data
273+
cls.strategy_type_used = "CropProduction"
274+
LivelihoodStrategyFactory(
275+
livelihood_zone_baseline=livelihood_zone_baseline,
276+
product=cls.product1,
277+
strategy_type=cls.strategy_type_used,
278+
season__country=cls.country_a,
206279
)
207280

208281
def setUp(self):
@@ -285,6 +358,33 @@ def test_filter_by_country(self):
285358
self.assertEqual(len(result), 1)
286359
self.assertEqual(self.product1.cpc, result[0]["cpc"])
287360

361+
def test_filter_by_livelihood_category(self):
362+
# product1 is in a baseline with category_a
363+
response = self.client.get(self.url, {"livelihood_category": self.category_a.code})
364+
self.assertEqual(response.status_code, 200)
365+
result = json.loads(response.content.decode("utf-8"))
366+
self.assertEqual(len(result), 1)
367+
self.assertEqual(result[0]["cpc"], self.product1.cpc)
368+
369+
# category_b has no strategies
370+
response = self.client.get(self.url, {"livelihood_category": self.category_b.code})
371+
self.assertEqual(response.status_code, 200)
372+
self.assertEqual(len(json.loads(response.content.decode("utf-8"))), 0)
373+
374+
def test_filter_by_strategy_type(self):
375+
# product1 has a CropProduction strategy
376+
response = self.client.get(self.url, {"strategy_type": self.strategy_type_used})
377+
self.assertEqual(response.status_code, 200)
378+
result = json.loads(response.content.decode("utf-8"))
379+
self.assertEqual(len(result), 1)
380+
self.assertEqual(result[0]["cpc"], self.product1.cpc)
381+
382+
# product2 has no strategies
383+
response = self.client.get(self.url, {"strategy_type": "FoodPurchase"})
384+
self.assertEqual(response.status_code, 200)
385+
cpcs = [r["cpc"] for r in json.loads(response.content.decode("utf-8"))]
386+
self.assertNotIn(self.product2.cpc, cpcs)
387+
288388

289389
class UserViewSetTestCase(APITestCase):
290390
def setUp(self):

apps/common/viewsets.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(self, *args, **kwargs):
9595
self.filters["country_code"].extra["choices"] = [
9696
(c.pk, c.iso_en_name) for c in Country.objects.all().order_by("iso_en_name")
9797
]
98+
LivelihoodStrategy = apps.get_model("baseline", "LivelihoodStrategy")
99+
self.filters["strategy_type"].extra["choices"] = LivelihoodStrategy._meta.get_field("strategy_type").choices
98100

99101
country_code = filters.MultipleChoiceFilter(
100102
field_name="iso3166a2",
@@ -119,6 +121,16 @@ def __init__(self, *args, **kwargs):
119121
)
120122

121123
has_wealthgroups = filters.BooleanFilter(method="filter_has_wealthgroups")
124+
livelihood_category = filters.ModelMultipleChoiceFilter(
125+
queryset=lambda request: apps.get_model("metadata", "LivelihoodCategory").objects.all(),
126+
method="filter_by_livelihood_category",
127+
label="Livelihood Category",
128+
)
129+
strategy_type = filters.MultipleChoiceFilter(
130+
choices=[],
131+
method="filter_by_strategy_type",
132+
label="Strategy Type",
133+
)
122134

123135
def filter_has_wealthgroups(self, queryset, name, value):
124136
if value is None:
@@ -132,6 +144,18 @@ def filter_has_wealthgroups(self, queryset, name, value):
132144
else:
133145
return queryset.exclude(Exists(wealth_group_exists))
134146

147+
def filter_by_livelihood_category(self, queryset, name, value):
148+
if not value:
149+
return queryset
150+
return queryset.filter(livelihoodzone__livelihoodzonebaseline__main_livelihood_category__in=value).distinct()
151+
152+
def filter_by_strategy_type(self, queryset, name, value):
153+
if not value:
154+
return queryset
155+
return queryset.filter(
156+
livelihoodzone__livelihoodzonebaseline__livelihood_strategies__strategy_type__in=value
157+
).distinct()
158+
135159

136160
class CountryViewSet(BaseModelViewSet):
137161
"""
@@ -289,6 +313,11 @@ class ClassifiedProductFilterSet(filters.FilterSet):
289313
The filter will display choices based on the available UnitOfMeasure objects.
290314
"""
291315

316+
def __init__(self, *args, **kwargs):
317+
super().__init__(*args, **kwargs)
318+
LivelihoodStrategy = apps.get_model("baseline", "LivelihoodStrategy")
319+
self.filters["strategy_type"].extra["choices"] = LivelihoodStrategy._meta.get_field("strategy_type").choices
320+
292321
cpc = filters.CharFilter(lookup_expr="icontains", label="CPC v2.1")
293322
description_en = filters.CharFilter(
294323
lookup_expr="icontains", label=format_lazy("{} ({})", _("Description"), _("English"))
@@ -323,6 +352,16 @@ class ClassifiedProductFilterSet(filters.FilterSet):
323352
unit_of_measure = filters.ModelChoiceFilter(queryset=UnitOfMeasure.objects.all(), field_name="unit_of_measure")
324353
has_wealthgroups = filters.BooleanFilter(method="filter_has_wealthgroups")
325354
country = CaseInsensitiveModelMultipleChoiceFilter(queryset=Country.objects.all(), method="filter_by_country")
355+
livelihood_category = filters.ModelMultipleChoiceFilter(
356+
queryset=lambda request: apps.get_model("metadata", "LivelihoodCategory").objects.all(),
357+
method="filter_by_livelihood_category",
358+
label="Livelihood Category",
359+
)
360+
strategy_type = filters.MultipleChoiceFilter(
361+
choices=[],
362+
method="filter_by_strategy_type",
363+
label="Strategy Type",
364+
)
326365

327366
def filter_has_wealthgroups(self, queryset, name, value):
328367
if value is None:
@@ -351,6 +390,18 @@ def filter_by_country(self, queryset, name, value):
351390

352391
return queryset.filter(country_queries).distinct()
353392

393+
def filter_by_livelihood_category(self, queryset, name, value):
394+
if not value:
395+
return queryset
396+
return queryset.filter(
397+
livelihood_strategies__livelihood_zone_baseline__main_livelihood_category__in=value
398+
).distinct()
399+
400+
def filter_by_strategy_type(self, queryset, name, value):
401+
if not value:
402+
return queryset
403+
return queryset.filter(livelihood_strategies__strategy_type__in=value).distinct()
404+
354405
class Meta:
355406
"""
356407
Metadata options for the ClassifiedProductFilterSet.

0 commit comments

Comments
 (0)