1+ import importlib
12import json
23
34from rest_framework .reverse import reverse
@@ -24,6 +25,12 @@ def setUpTestData(cls):
2425 cls .country3 = CountryFactory ()
2526 cls .country4 = CountryFactory ()
2627
28+ # import baseline factory to avoid circular depdnecies
29+ module = importlib .import_module ("baseline.tests.factories" )
30+ WealthGroupFactory = getattr (module , "WealthGroupFactory" )
31+
32+ cls .wealth_group1 = WealthGroupFactory (livelihood_zone_baseline__livelihood_zone__country = cls .country1 )
33+
2734 def setUp (self ):
2835 self .url = reverse ("country-list" )
2936
@@ -81,6 +88,23 @@ def test_search_by_country_name(self):
8188 result = json .loads (response .content .decode ("utf-8" ))
8289 self .assertEqual (len (result ), 1 )
8390
91+ def test_filter_by_has_wealthgroups (self ):
92+ # test by has_wealthgroups set to true
93+ filter_data = {"has_wealthgroups" : True }
94+ response = self .client .get (self .url , filter_data )
95+ self .assertEqual (response .status_code , 200 )
96+ result = json .loads (response .content .decode ("utf-8" ))
97+ self .assertEqual (len (result ), 1 )
98+ self .assertEqual (self .country1 .name , result [0 ]["name" ])
99+ self .assertEqual (self .wealth_group1 .livelihood_zone_baseline .livelihood_zone .country , self .country1 )
100+
101+ # test by has_wealthgroups set to false
102+ filter_data = {"has_wealthgroups" : False }
103+ response = self .client .get (self .url , filter_data )
104+ self .assertEqual (response .status_code , 200 )
105+ result = json .loads (response .content .decode ("utf-8" ))
106+ self .assertNotIn (self .country1 .iso3166a2 , result )
107+
84108
85109class CurrencyViewSetTestCase (APITestCase ):
86110 @classmethod
@@ -152,6 +176,35 @@ def setUpTestData(cls):
152176 cls .product2 = ClassifiedProductFactory ()
153177 cls .superuser = UserFactory (is_superuser = True , is_staff = True , is_active = True )
154178
179+ cls .country_a = CountryFactory (
180+ iso3166a2 = "AA" ,
181+ iso3166a3 = "AAA" ,
182+ iso3166n3 = 911 ,
183+ iso_en_ro_name = "A Country" ,
184+ iso_en_name = "AA Country" ,
185+ name = "AA Country" ,
186+ )
187+ cls .country_b = CountryFactory (
188+ iso3166a2 = "BB" ,
189+ iso3166a3 = "BBB" ,
190+ iso3166n3 = 912 ,
191+ iso_en_ro_name = "B Country" ,
192+ iso_en_name = "BB Country" ,
193+ name = "BB Country" ,
194+ )
195+
196+ # import baseline factory to avoid circular depdnecies
197+ module = importlib .import_module ("baseline.tests.factories" )
198+ WealthGroupFactory = getattr (module , "WealthGroupFactory" )
199+ LivelihoodStrategyFactory = getattr (module , "LivelihoodStrategyFactory" )
200+ LivelihoodZoneBaselineFactory = getattr (module , "LivelihoodZoneBaselineFactory" )
201+
202+ livelihood_zone_baseline = LivelihoodZoneBaselineFactory (livelihood_zone__country = cls .country_a )
203+ WealthGroupFactory (livelihood_zone_baseline = livelihood_zone_baseline )
204+ cls .livelihood_strategy1 = LivelihoodStrategyFactory (
205+ livelihood_zone_baseline = livelihood_zone_baseline , product = cls .product1
206+ )
207+
155208 def setUp (self ):
156209 self .url = reverse ("classifiedproduct-list" )
157210
@@ -193,6 +246,45 @@ def test_search_fields(self):
193246 result = json .loads (response .content .decode ("utf-8" ))
194247 self .assertEqual (len (result ), 1 )
195248
249+ def test_filter_by_has_wealthgroups (self ):
250+ # test by has_wealthgroups set to true
251+ filter_data = {"has_wealthgroups" : True }
252+ response = self .client .get (self .url , filter_data )
253+ self .assertEqual (response .status_code , 200 )
254+ result = json .loads (response .content .decode ("utf-8" ))
255+ self .assertEqual (len (result ), 1 )
256+ self .assertEqual (self .product1 .cpc , result [0 ]["cpc" ])
257+ self .assertEqual (self .livelihood_strategy1 .product , self .product1 )
258+
259+ # test by has_wealthgroups set to false
260+ filter_data = {"has_wealthgroups" : False }
261+ response = self .client .get (self .url , filter_data )
262+ self .assertEqual (response .status_code , 200 )
263+ result = json .loads (response .content .decode ("utf-8" ))
264+ self .assertNotIn (self .product1 .cpc , result )
265+
266+ def test_filter_by_country (self ):
267+ # test filter by country
268+ filter_data = {"country" : self .country_a .iso3166a2 }
269+ response = self .client .get (self .url , filter_data )
270+ self .assertEqual (response .status_code , 200 )
271+ result = json .loads (response .content .decode ("utf-8" ))
272+ self .assertEqual (len (result ), 1 )
273+ self .assertEqual (self .product1 .cpc , result [0 ]["cpc" ])
274+
275+ filter_data = {"country" : self .country_b .iso3166a2 }
276+ response = self .client .get (self .url , filter_data )
277+ self .assertEqual (response .status_code , 200 )
278+ result = json .loads (response .content .decode ("utf-8" ))
279+ self .assertEqual (len (result ), 0 )
280+
281+ filter_data = {"country" : self .country_a .iso3166a2 .lower ()}
282+ response = self .client .get (self .url , filter_data )
283+ self .assertEqual (response .status_code , 200 )
284+ result = json .loads (response .content .decode ("utf-8" ))
285+ self .assertEqual (len (result ), 1 )
286+ self .assertEqual (self .product1 .cpc , result [0 ]["cpc" ])
287+
196288
197289class UserViewSetTestCase (APITestCase ):
198290 def setUp (self ):
0 commit comments