@@ -28,9 +28,30 @@ def setUpTestData(cls):
2828 # import baseline factory to avoid circular depdnecies
2929 module = importlib .import_module ("baseline.tests.factories" )
3030 WealthGroupFactory = getattr (module , "WealthGroupFactory" )
31+ LivelihoodZoneBaselineFactory = getattr (module , "LivelihoodZoneBaselineFactory" )
32+ LivelihoodStrategyFactory = getattr (module , "LivelihoodStrategyFactory" )
3133
3234 cls .wealth_group1 = WealthGroupFactory (livelihood_zone_baseline__livelihood_zone__country = cls .country1 )
3335
36+ # Set up livelihood category cascade filter test data
37+ metadata_module = importlib .import_module ("metadata.tests.factories" )
38+ LivelihoodCategoryFactory = getattr (metadata_module , "LivelihoodCategoryFactory" )
39+ cls .category_a = LivelihoodCategoryFactory ()
40+ cls .category_b = LivelihoodCategoryFactory ()
41+ LivelihoodZoneBaselineFactory (
42+ livelihood_zone__country = cls .country2 ,
43+ main_livelihood_category = cls .category_a ,
44+ )
45+
46+ # Set up strategy_type cascade filter test data
47+ cls .strategy_type = "CropProduction"
48+ baseline_with_strategy = LivelihoodZoneBaselineFactory (livelihood_zone__country = cls .country3 )
49+ LivelihoodStrategyFactory (
50+ livelihood_zone_baseline = baseline_with_strategy ,
51+ strategy_type = cls .strategy_type ,
52+ season__country = cls .country3 ,
53+ )
54+
3455 def setUp (self ):
3556 self .url = reverse ("country-list" )
3657
@@ -105,6 +126,37 @@ def test_filter_by_has_wealthgroups(self):
105126 result = json .loads (response .content .decode ("utf-8" ))
106127 self .assertNotIn (self .country1 .iso3166a2 , result )
107128
129+ def test_filter_by_livelihood_category (self ):
130+ # country2 has a baseline with category_a
131+ response = self .client .get (self .url , {"livelihood_category" : self .category_a .code })
132+ self .assertEqual (response .status_code , 200 )
133+ result = json .loads (response .content .decode ("utf-8" ))
134+ country_codes = [r ["iso3166a2" ] for r in result ]
135+ self .assertIn (self .country2 .iso3166a2 , country_codes )
136+ self .assertNotIn (self .country4 .iso3166a2 , country_codes )
137+
138+ # category_b has no baselines
139+ response = self .client .get (self .url , {"livelihood_category" : self .category_b .code })
140+ self .assertEqual (response .status_code , 200 )
141+ result = json .loads (response .content .decode ("utf-8" ))
142+ self .assertEqual (len (result ), 0 )
143+
144+ def test_filter_by_strategy_type (self ):
145+ # country3 has a baseline with CropProduction strategy
146+ response = self .client .get (self .url , {"strategy_type" : self .strategy_type })
147+ self .assertEqual (response .status_code , 200 )
148+ result = json .loads (response .content .decode ("utf-8" ))
149+ country_codes = [r ["iso3166a2" ] for r in result ]
150+ self .assertIn (self .country3 .iso3166a2 , country_codes )
151+ self .assertNotIn (self .country4 .iso3166a2 , country_codes )
152+
153+ # country4 has no strategies
154+ response = self .client .get (self .url , {"strategy_type" : "FoodPurchase" })
155+ self .assertEqual (response .status_code , 200 )
156+ result = json .loads (response .content .decode ("utf-8" ))
157+ country_codes = [r ["iso3166a2" ] for r in result ]
158+ self .assertNotIn (self .country4 .iso3166a2 , country_codes )
159+
108160
109161class CurrencyViewSetTestCase (APITestCase ):
110162 @classmethod
@@ -202,7 +254,28 @@ def setUpTestData(cls):
202254 livelihood_zone_baseline = LivelihoodZoneBaselineFactory (livelihood_zone__country = cls .country_a )
203255 WealthGroupFactory (livelihood_zone_baseline = livelihood_zone_baseline )
204256 cls .livelihood_strategy1 = LivelihoodStrategyFactory (
205- livelihood_zone_baseline = livelihood_zone_baseline , product = cls .product1
257+ livelihood_zone_baseline = livelihood_zone_baseline , product = cls .product1 , season__country = cls .country_a
258+ )
259+
260+ # Set up livelihood category cascade filter test data
261+ metadata_module = importlib .import_module ("metadata.tests.factories" )
262+ LivelihoodCategoryFactory = getattr (metadata_module , "LivelihoodCategoryFactory" )
263+ cls .category_a = LivelihoodCategoryFactory ()
264+ cls .category_b = LivelihoodCategoryFactory ()
265+ baseline_category_a = LivelihoodZoneBaselineFactory (
266+ livelihood_zone__country = cls .country_a , main_livelihood_category = cls .category_a
267+ )
268+ LivelihoodStrategyFactory (
269+ livelihood_zone_baseline = baseline_category_a , product = cls .product1 , season__country = cls .country_a
270+ )
271+
272+ # Set up strategy_type cascade filter test data
273+ cls .strategy_type_used = "CropProduction"
274+ LivelihoodStrategyFactory (
275+ livelihood_zone_baseline = livelihood_zone_baseline ,
276+ product = cls .product1 ,
277+ strategy_type = cls .strategy_type_used ,
278+ season__country = cls .country_a ,
206279 )
207280
208281 def setUp (self ):
@@ -285,6 +358,33 @@ def test_filter_by_country(self):
285358 self .assertEqual (len (result ), 1 )
286359 self .assertEqual (self .product1 .cpc , result [0 ]["cpc" ])
287360
361+ def test_filter_by_livelihood_category (self ):
362+ # product1 is in a baseline with category_a
363+ response = self .client .get (self .url , {"livelihood_category" : self .category_a .code })
364+ self .assertEqual (response .status_code , 200 )
365+ result = json .loads (response .content .decode ("utf-8" ))
366+ self .assertEqual (len (result ), 1 )
367+ self .assertEqual (result [0 ]["cpc" ], self .product1 .cpc )
368+
369+ # category_b has no strategies
370+ response = self .client .get (self .url , {"livelihood_category" : self .category_b .code })
371+ self .assertEqual (response .status_code , 200 )
372+ self .assertEqual (len (json .loads (response .content .decode ("utf-8" ))), 0 )
373+
374+ def test_filter_by_strategy_type (self ):
375+ # product1 has a CropProduction strategy
376+ response = self .client .get (self .url , {"strategy_type" : self .strategy_type_used })
377+ self .assertEqual (response .status_code , 200 )
378+ result = json .loads (response .content .decode ("utf-8" ))
379+ self .assertEqual (len (result ), 1 )
380+ self .assertEqual (result [0 ]["cpc" ], self .product1 .cpc )
381+
382+ # product2 has no strategies
383+ response = self .client .get (self .url , {"strategy_type" : "FoodPurchase" })
384+ self .assertEqual (response .status_code , 200 )
385+ cpcs = [r ["cpc" ] for r in json .loads (response .content .decode ("utf-8" ))]
386+ self .assertNotIn (self .product2 .cpc , cpcs )
387+
288388
289389class UserViewSetTestCase (APITestCase ):
290390 def setUp (self ):
0 commit comments