Skip to content

Commit dab35b7

Browse files
committed
Merge branch 'main' into HEA-789/Create-concept-that-splits-one-expenditure-item-between-survival-or-LHP-and-Other
2 parents 3a7b777 + a4a0d9f commit dab35b7

13 files changed

Lines changed: 653 additions & 58 deletions

File tree

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Generated by Django 5.2.7 on 2025-11-05 10:09
2+
3+
import django.db.models.deletion
4+
from django.db import migrations, models
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
("baseline", "0019_alter_foodpurchase_times_per_year_and_more"),
11+
("metadata", "0012_alter_activitylabel_activity_type_and_more"),
12+
]
13+
14+
operations = [
15+
migrations.AlterField(
16+
model_name="wealthgroup",
17+
name="wealth_group_category",
18+
field=models.ForeignKey(
19+
db_column="wealth_group_category_code",
20+
help_text="Wealth Group Category, e.g. Poor or Better Off",
21+
on_delete=django.db.models.deletion.CASCADE,
22+
related_name="wealth_groups",
23+
to="metadata.wealthgroupcategory",
24+
verbose_name="Wealth Group Category",
25+
),
26+
),
27+
]

apps/baseline/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ class WealthGroup(common_models.Model):
505505
)
506506
wealth_group_category = models.ForeignKey(
507507
WealthGroupCategory,
508+
related_name="wealth_groups",
508509
db_column="wealth_group_category_code",
509510
on_delete=models.CASCADE,
510511
verbose_name=_("Wealth Group Category"),

apps/baseline/tests/test_viewsets.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,45 @@ def test_filter_by_country(self):
14131413
self.assertEqual(response.status_code, 200)
14141414
self.assertEqual(len(json.loads(response.content.decode("utf-8"))), 1)
14151415

1416+
def test_filter_by_product(self):
1417+
parent = ClassifiedProductFactory(cpc="K011")
1418+
product = ClassifiedProductFactory(
1419+
cpc="K0111",
1420+
description_en="my product",
1421+
common_name_en="common",
1422+
kcals_per_unit=550,
1423+
parent=parent,
1424+
aliases=["test alias"],
1425+
)
1426+
ClassifiedProductFactory(cpc="K01111")
1427+
characteristic1 = WealthCharacteristicFactory(
1428+
code="IOO", description_en="item one ownership", has_product=True
1429+
)
1430+
WealthGroupCharacteristicValueFactory(wealth_characteristic=characteristic1, product=product)
1431+
response = self.client.get(self.url, {"product": "K011"})
1432+
self.assertEqual(response.status_code, 200)
1433+
self.assertEqual(len(json.loads(response.content)), 1)
1434+
# filter by cpc
1435+
response = self.client.get(self.url, {"product": "K0111"})
1436+
self.assertEqual(response.status_code, 200)
1437+
self.assertEqual(len(json.loads(response.content)), 1)
1438+
# filter by cpc startswith
1439+
response = self.client.get(self.url, {"product": "K01111"})
1440+
self.assertEqual(response.status_code, 200)
1441+
self.assertEqual(len(json.loads(response.content)), 0)
1442+
# filter by description icontains
1443+
response = self.client.get(self.url, {"product": "my"})
1444+
self.assertEqual(response.status_code, 200)
1445+
self.assertEqual(len(json.loads(response.content.decode("utf-8"))), 1)
1446+
# filter by description
1447+
response = self.client.get(self.url, {"product": "my product"})
1448+
self.assertEqual(response.status_code, 200)
1449+
self.assertEqual(len(json.loads(response.content.decode("utf-8"))), 1)
1450+
# filter by alias
1451+
response = self.client.get(self.url, {"product": "test"})
1452+
self.assertEqual(response.status_code, 200)
1453+
self.assertEqual(len(json.loads(response.content.decode("utf-8"))), 1)
1454+
14161455

14171456
class LivelihoodStrategyViewSetTestCase(APITestCase):
14181457
@classmethod

apps/baseline/viewsets.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,15 @@ class Meta:
471471
lookup_expr="iexact",
472472
label="Country",
473473
)
474+
product = MultiFieldFilter(
475+
[
476+
*[(field, "icontains") for field in translation_fields("product__common_name")],
477+
("product__cpc", "istartswith"),
478+
*[(field, "icontains") for field in translation_fields("product__description")],
479+
("product__aliases", "icontains"),
480+
],
481+
label="Product",
482+
)
474483

475484

476485
class WealthGroupCharacteristicValueViewSet(BaseModelViewSet):

apps/common/lookups.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def do_lookup(
307307
if lookup_column:
308308
merge_df = merge_df.drop(["lookup_candidate", "lookup_key"], axis="columns")
309309

310-
return merge_df
310+
# Preserve the original index
311+
return merge_df.set_index(df.index)
311312

312313
def get_instances(self, df, column, related_models=None):
313314
"""

apps/common/tests/test_viewsets.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib
12
import json
23

34
from rest_framework.reverse import reverse
@@ -24,6 +25,12 @@ def setUpTestData(cls):
2425
cls.country3 = CountryFactory()
2526
cls.country4 = CountryFactory()
2627

28+
# import baseline factory to avoid circular depdnecies
29+
module = importlib.import_module("baseline.tests.factories")
30+
WealthGroupFactory = getattr(module, "WealthGroupFactory")
31+
32+
cls.wealth_group1 = WealthGroupFactory(livelihood_zone_baseline__livelihood_zone__country=cls.country1)
33+
2734
def setUp(self):
2835
self.url = reverse("country-list")
2936

@@ -81,6 +88,23 @@ def test_search_by_country_name(self):
8188
result = json.loads(response.content.decode("utf-8"))
8289
self.assertEqual(len(result), 1)
8390

91+
def test_filter_by_has_wealthgroups(self):
92+
# test by has_wealthgroups set to true
93+
filter_data = {"has_wealthgroups": True}
94+
response = self.client.get(self.url, filter_data)
95+
self.assertEqual(response.status_code, 200)
96+
result = json.loads(response.content.decode("utf-8"))
97+
self.assertEqual(len(result), 1)
98+
self.assertEqual(self.country1.name, result[0]["name"])
99+
self.assertEqual(self.wealth_group1.livelihood_zone_baseline.livelihood_zone.country, self.country1)
100+
101+
# test by has_wealthgroups set to false
102+
filter_data = {"has_wealthgroups": False}
103+
response = self.client.get(self.url, filter_data)
104+
self.assertEqual(response.status_code, 200)
105+
result = json.loads(response.content.decode("utf-8"))
106+
self.assertNotIn(self.country1.iso3166a2, result)
107+
84108

85109
class CurrencyViewSetTestCase(APITestCase):
86110
@classmethod
@@ -152,6 +176,35 @@ def setUpTestData(cls):
152176
cls.product2 = ClassifiedProductFactory()
153177
cls.superuser = UserFactory(is_superuser=True, is_staff=True, is_active=True)
154178

179+
cls.country_a = CountryFactory(
180+
iso3166a2="AA",
181+
iso3166a3="AAA",
182+
iso3166n3=911,
183+
iso_en_ro_name="A Country",
184+
iso_en_name="AA Country",
185+
name="AA Country",
186+
)
187+
cls.country_b = CountryFactory(
188+
iso3166a2="BB",
189+
iso3166a3="BBB",
190+
iso3166n3=912,
191+
iso_en_ro_name="B Country",
192+
iso_en_name="BB Country",
193+
name="BB Country",
194+
)
195+
196+
# import baseline factory to avoid circular depdnecies
197+
module = importlib.import_module("baseline.tests.factories")
198+
WealthGroupFactory = getattr(module, "WealthGroupFactory")
199+
LivelihoodStrategyFactory = getattr(module, "LivelihoodStrategyFactory")
200+
LivelihoodZoneBaselineFactory = getattr(module, "LivelihoodZoneBaselineFactory")
201+
202+
livelihood_zone_baseline = LivelihoodZoneBaselineFactory(livelihood_zone__country=cls.country_a)
203+
WealthGroupFactory(livelihood_zone_baseline=livelihood_zone_baseline)
204+
cls.livelihood_strategy1 = LivelihoodStrategyFactory(
205+
livelihood_zone_baseline=livelihood_zone_baseline, product=cls.product1
206+
)
207+
155208
def setUp(self):
156209
self.url = reverse("classifiedproduct-list")
157210

@@ -193,6 +246,45 @@ def test_search_fields(self):
193246
result = json.loads(response.content.decode("utf-8"))
194247
self.assertEqual(len(result), 1)
195248

249+
def test_filter_by_has_wealthgroups(self):
250+
# test by has_wealthgroups set to true
251+
filter_data = {"has_wealthgroups": True}
252+
response = self.client.get(self.url, filter_data)
253+
self.assertEqual(response.status_code, 200)
254+
result = json.loads(response.content.decode("utf-8"))
255+
self.assertEqual(len(result), 1)
256+
self.assertEqual(self.product1.cpc, result[0]["cpc"])
257+
self.assertEqual(self.livelihood_strategy1.product, self.product1)
258+
259+
# test by has_wealthgroups set to false
260+
filter_data = {"has_wealthgroups": False}
261+
response = self.client.get(self.url, filter_data)
262+
self.assertEqual(response.status_code, 200)
263+
result = json.loads(response.content.decode("utf-8"))
264+
self.assertNotIn(self.product1.cpc, result)
265+
266+
def test_filter_by_country(self):
267+
# test filter by country
268+
filter_data = {"country": self.country_a.iso3166a2}
269+
response = self.client.get(self.url, filter_data)
270+
self.assertEqual(response.status_code, 200)
271+
result = json.loads(response.content.decode("utf-8"))
272+
self.assertEqual(len(result), 1)
273+
self.assertEqual(self.product1.cpc, result[0]["cpc"])
274+
275+
filter_data = {"country": self.country_b.iso3166a2}
276+
response = self.client.get(self.url, filter_data)
277+
self.assertEqual(response.status_code, 200)
278+
result = json.loads(response.content.decode("utf-8"))
279+
self.assertEqual(len(result), 0)
280+
281+
filter_data = {"country": self.country_a.iso3166a2.lower()}
282+
response = self.client.get(self.url, filter_data)
283+
self.assertEqual(response.status_code, 200)
284+
result = json.loads(response.content.decode("utf-8"))
285+
self.assertEqual(len(result), 1)
286+
self.assertEqual(self.product1.cpc, result[0]["cpc"])
287+
196288

197289
class UserViewSetTestCase(APITestCase):
198290
def setUp(self):

apps/common/viewsets.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from django.apps import apps
12
from django.contrib.auth.models import User
3+
from django.db.models import Exists, OuterRef, Q
24
from django.utils.text import format_lazy
35
from django.utils.translation import gettext_lazy as _
46
from django_filters import rest_framework as filters
@@ -8,6 +10,8 @@
810
from rest_framework.pagination import PageNumberPagination
911
from rest_framework.permissions import BasePermission, IsAuthenticated
1012

13+
from common.filters import CaseInsensitiveModelMultipleChoiceFilter
14+
1115
from .fields import translation_fields
1216
from .filters import MultiFieldFilter
1317
from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure, UserProfile
@@ -98,6 +102,20 @@ def __init__(self, *args, **kwargs):
98102
label="Country",
99103
)
100104

105+
has_wealthgroups = filters.BooleanFilter(method="filter_has_wealthgroups")
106+
107+
def filter_has_wealthgroups(self, queryset, name, value):
108+
if value is None:
109+
return queryset
110+
WealthGroup = apps.get_model("baseline", "WealthGroup")
111+
wealth_group_exists = WealthGroup.objects.filter(
112+
livelihood_zone_baseline__livelihood_zone__country=OuterRef("pk")
113+
)
114+
if value:
115+
return queryset.filter(Exists(wealth_group_exists))
116+
else:
117+
return queryset.exclude(Exists(wealth_group_exists))
118+
101119

102120
class CountryViewSet(BaseModelViewSet):
103121
"""
@@ -287,6 +305,35 @@ class ClassifiedProductFilterSet(filters.FilterSet):
287305
lookup_expr="icontains", label=format_lazy("{} ({})", _("Common Name"), _("Portuguese"))
288306
)
289307
unit_of_measure = filters.ModelChoiceFilter(queryset=UnitOfMeasure.objects.all(), field_name="unit_of_measure")
308+
has_wealthgroups = filters.BooleanFilter(method="filter_has_wealthgroups")
309+
country = CaseInsensitiveModelMultipleChoiceFilter(queryset=Country.objects.all(), method="filter_by_country")
310+
311+
def filter_has_wealthgroups(self, queryset, name, value):
312+
if value is None:
313+
return queryset
314+
315+
WealthGroup = apps.get_model("baseline", "WealthGroup")
316+
317+
# Get baseline IDs that have wealth groups
318+
baselines_with_wg = WealthGroup.objects.values_list("livelihood_zone_baseline_id", flat=True).distinct()
319+
320+
if value:
321+
# Return products with strategies in those baselines
322+
return queryset.filter(livelihood_strategies__livelihood_zone_baseline_id__in=baselines_with_wg).distinct()
323+
else:
324+
return queryset
325+
326+
def filter_by_country(self, queryset, name, value):
327+
if not value:
328+
return queryset
329+
330+
country_queries = Q()
331+
for country in value:
332+
country_queries |= Q(
333+
livelihood_strategies__livelihood_zone_baseline__livelihood_zone__country__iso3166a2__iexact=country.iso3166a2
334+
)
335+
336+
return queryset.filter(country_queries).distinct()
290337

291338
class Meta:
292339
"""

0 commit comments

Comments
 (0)