@@ -10,8 +10,8 @@
 from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad


-@mark.parametrize("A", [Mean(), UPGrad(), MGDA(), Random()])
-def test_backward_various_aggregators(A: Aggregator):
+@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
+def test_backward_various_aggregators(aggregator: Aggregator):
     """Tests that backward works for various aggregators."""

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
@@ -21,17 +21,17 @@ def test_backward_various_aggregators(A: Aggregator):
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A)
+    backward([y1, y2], aggregator)

     for p in params:
         assert (p.grad is not None) and (p.shape == p.grad.shape)


-@mark.parametrize("A", [Mean(), UPGrad(), MGDA()])
+@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()])
 @mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)])
 @mark.parametrize("manually_specify_inputs", [True, False])
 def test_backward_value_is_correct(
-    A: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
+    aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
 ):
     """
     Tests that the .grad value filled by backward is correct in a simple example of matrix-vector
@@ -47,15 +47,15 @@ def test_backward_value_is_correct(
     else:
         inputs = None

-    backward([output], A, inputs=inputs)
+    backward([output], aggregator, inputs=inputs)

-    assert_close(input.grad, A(J))
+    assert_close(input.grad, aggregator(J))


 def test_backward_empty_inputs():
     """Tests that backward does not fill the .grad values if no input is specified."""

-    A = Mean()
+    aggregator = Mean()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
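
The assertion assert_close(input.grad, aggregator(J)) in the hunk above relies on an Aggregator being callable on a Jacobian: given the matrix whose rows are the gradients of the m objectives, it returns a single parameter-shaped update vector. A minimal sketch of that contract, assuming Mean averages the Jacobian's rows (i.e. is equivalent to J.mean(dim=0)):

import torch
from torchjd.aggregation import Mean

J = torch.tensor([[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]])  # 2 objectives, 3 parameters
# Mean is assumed to weight each row equally, giving the row-wise average.
assert torch.allclose(Mean()(J), J.mean(dim=0))  # tensor([3., 4., 5.])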
@@ -64,7 +64,7 @@ def test_backward_empty_inputs():
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, inputs=[])
+    backward([y1, y2], aggregator, inputs=[])

     for p in params:
         assert p.grad is None
@@ -76,15 +76,15 @@ def test_backward_partial_inputs():
     specified as inputs.
     """

-    A = Mean()
+    aggregator = Mean()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, inputs=[p1])
+    backward([y1, y2], aggregator, inputs=[p1])

     assert (p1.grad is not None) and (p1.shape == p1.grad.shape)
     assert p2.grad is None
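
The two hunks above pin down the semantics of the inputs argument: an empty list fills no .grad, and a partial list fills only the listed leaf tensors. A short sketch of that behaviour, assuming backward is importable from the torchjd package root as in this test module:

import torch
from torchjd import backward
from torchjd.aggregation import Mean

p1 = torch.tensor([1.0, 2.0], requires_grad=True)
p2 = torch.tensor([3.0, 4.0], requires_grad=True)

# Restrict differentiation to p1: p2 keeps grad=None even though it
# contributes to one of the objectives.
backward([p1.sum(), (p2**2).sum()], Mean(), inputs=[p1])
assert p1.grad is not None and p2.grad is None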
@@ -93,13 +93,13 @@ def test_backward_partial_inputs():
 def test_backward_empty_tensors():
     """Tests that backward raises an error when called with an empty list of tensors."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

     with raises(ValueError):
-        backward([], A, inputs=[p1, p2])
+        backward([], aggregator, inputs=[p1, p2])


 def test_backward_multiple_tensors():
@@ -108,7 +108,7 @@ def test_backward_multiple_tensors():
     containing all the values of the original tensors.
     """

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -117,13 +117,13 @@ def test_backward_multiple_tensors():
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, retain_graph=True)
+    backward([y1, y2], aggregator, retain_graph=True)

     param_to_grad = {p: p.grad for p in params}
     for p in params:
         p.grad = None

-    backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), A)
+    backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator)

     for p in params:
         assert (p.grad == param_to_grad[p]).all()
@@ -133,7 +133,7 @@ def test_backward_multiple_tensors():
 def test_backward_valid_chunk_size(chunk_size):
     """Tests that backward works for various valid values of parallel_chunk_size."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -142,7 +142,7 @@ def test_backward_valid_chunk_size(chunk_size):
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, parallel_chunk_size=chunk_size, retain_graph=True)
+    backward([y1, y2], aggregator, parallel_chunk_size=chunk_size, retain_graph=True)

     for p in params:
         assert (p.grad is not None) and (p.shape == p.grad.shape)
@@ -152,7 +152,7 @@ def test_backward_valid_chunk_size(chunk_size):
 def test_backward_non_positive_chunk_size(chunk_size: int):
     """Tests that backward raises an error when using invalid chunk sizes."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -161,7 +161,7 @@ def test_backward_non_positive_chunk_size(chunk_size: int):
     y2 = (p1**2).sum() + p2.norm()

     with raises(ValueError):
-        backward([y1, y2], A, parallel_chunk_size=chunk_size)
+        backward([y1, y2], aggregator, parallel_chunk_size=chunk_size)


 @mark.parametrize(
@@ -174,7 +174,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
     large enough to allow differentiation of all tensors at once.
     """

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -183,7 +183,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
     y2 = (p1**2).sum() + p2.norm()

     with expectation:
-        backward([y1, y2], A, retain_graph=False, parallel_chunk_size=chunk_size)
+        backward([y1, y2], aggregator, retain_graph=False, parallel_chunk_size=chunk_size)


 def test_backward_fails_with_input_retaining_grad():
@@ -198,7 +198,7 @@ def test_backward_fails_with_input_retaining_grad():
     c = 3 * b

     with raises(RuntimeError):
-        backward(tensors=c, A=UPGrad(), inputs=[b])
+        backward(tensors=c, aggregator=UPGrad(), inputs=[b])


 def test_backward_fails_with_non_input_retaining_grad():
@@ -213,7 +213,7 @@ def test_backward_fails_with_non_input_retaining_grad():
     c = 3 * b

     # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor
-    backward(tensors=c, A=UPGrad(), inputs=[a])
+    backward(tensors=c, aggregator=UPGrad(), inputs=[a])

     with raises(RuntimeError):
         # Using such a BatchedTensor should result in an error
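
Taken together, these tests exercise every parameter of backward that this diff touches: a tensor or list of tensors, the renamed aggregator argument, and the optional inputs, retain_graph, and parallel_chunk_size arguments. A hedged end-to-end sketch under the same import assumption as above:

import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

p = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = p.sum()
y2 = (p**2).sum()

# Differentiate both objectives chunk by chunk; retain_graph=True is needed
# here because a chunk size of 1 re-traverses the graph once per objective.
backward([y1, y2], UPGrad(), retain_graph=True, parallel_chunk_size=1)
assert p.grad is not None and p.grad.shape == p.shape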