11import torch
2- from pytest import raises
2+ from pytest import mark , raises
33from unit .conftest import DEVICE
44
55from torchjd .autojac ._transform import Jac , Jacobians
66
77from ._dict_assertions import assert_tensor_dicts_are_close
88
99
10- def test_single_input ():
10+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
11+ def test_single_input (chunk_size : int | None ):
1112 """
1213 Tests that the Jac transform works correctly for an example of multiple differentiation. Here,
1314 the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with
@@ -20,7 +21,7 @@ def test_single_input():
2021 y = torch .stack ([a1 * x , a2 * x ])
2122 input = Jacobians ({y : torch .eye (2 , device = DEVICE )})
2223
23- jac = Jac (outputs = [y ], inputs = [a1 , a2 ], chunk_size = None )
24+ jac = Jac (outputs = [y ], inputs = [a1 , a2 ], chunk_size = chunk_size )
2425
2526 jacobians = jac (input )
2627 expected_jacobians = {
@@ -31,7 +32,8 @@ def test_single_input():
3132 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
3233
3334
34- def test_empty_inputs_1 ():
35+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
36+ def test_empty_inputs_1 (chunk_size : int | None ):
3537 """
3638 Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
3739 """
@@ -41,15 +43,16 @@ def test_empty_inputs_1():
4143 y = torch .stack ([y1 , y2 ])
4244 input = Jacobians ({y : torch .eye (2 , device = DEVICE )})
4345
44- jac = Jac (outputs = [y ], inputs = [], chunk_size = None )
46+ jac = Jac (outputs = [y ], inputs = [], chunk_size = chunk_size )
4547
4648 jacobians = jac (input )
4749 expected_jacobians = {}
4850
4951 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
5052
5153
52- def test_empty_inputs_2 ():
54+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
55+ def test_empty_inputs_2 (chunk_size : int | None ):
5356 """
5457 Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
5558 """
@@ -62,7 +65,7 @@ def test_empty_inputs_2():
6265 y = torch .stack ([y1 , y2 ])
6366 input = Jacobians ({y : torch .eye (2 , device = DEVICE )})
6467
65- jac = Jac (outputs = [y ], inputs = [], chunk_size = None )
68+ jac = Jac (outputs = [y ], inputs = [], chunk_size = chunk_size )
6669
6770 jacobians = jac (input )
6871 expected_jacobians = {}
@@ -122,7 +125,8 @@ def test_two_levels():
122125 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
123126
124127
125- def test_multiple_outputs_1 ():
128+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
129+ def test_multiple_outputs_1 (chunk_size : int | None ):
126130 """
127131 Tests that the Jac transform works correctly when the `outputs` contains 3 vectors.
128132 The input (jac_outputs) is not the same for all outputs, so that this test also checks that the
@@ -143,7 +147,7 @@ def test_multiple_outputs_1():
143147 jac_output3 = torch .cat ([zeros_2x2 , zeros_2x2 , identity_2x2 ])
144148 input = Jacobians ({y1 : jac_output1 , y2 : jac_output2 , y3 : jac_output3 })
145149
146- jac = Jac (outputs = [y1 , y2 , y3 ], inputs = [a1 , a2 ], chunk_size = None )
150+ jac = Jac (outputs = [y1 , y2 , y3 ], inputs = [a1 , a2 ], chunk_size = chunk_size )
147151
148152 jacobians = jac (input )
149153 zero_scalar = torch .tensor (0.0 , device = DEVICE )
@@ -155,7 +159,8 @@ def test_multiple_outputs_1():
155159 assert_tensor_dicts_are_close (jacobians , expected_jacobians )
156160
157161
158- def test_multiple_outputs_2 ():
162+ @mark .parametrize ("chunk_size" , [1 , 3 , None ])
163+ def test_multiple_outputs_2 (chunk_size : int | None ):
159164 """
160165 Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of
161166 different shapes.
@@ -175,7 +180,7 @@ def test_multiple_outputs_2():
175180 jac_output3 = torch .stack ([zeros_2 , zeros_2 , ones_2 ])
176181 input = Jacobians ({y1 : jac_output1 , y2 : jac_output2 , y3 : jac_output3 })
177182
178- jac = Jac (outputs = [y1 , y2 , y3 ], inputs = [a1 , a2 ], chunk_size = None )
183+ jac = Jac (outputs = [y1 , y2 , y3 ], inputs = [a1 , a2 ], chunk_size = chunk_size )
179184
180185 jacobians = jac (input )
181186 expected_jacobians = {
0 commit comments