@@ -12,7 +12,7 @@ the parameters are updated using the resulting aggregation.

 Import several classes from ``torch`` and ``torchjd``:

-.. code-block:: python
+.. testcode::

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -24,14 +24,14 @@ Import several classes from ``torch`` and ``torchjd``:

 Define the model and the optimizer, as usual:

-.. code-block:: python
+.. testcode::

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    optimizer = SGD(model.parameters(), lr=0.1)

 Define the aggregator that will be used to combine the Jacobian matrix:

-.. code-block:: python
+.. testcode::

    aggregator = UPGrad()

@@ -41,7 +41,7 @@ negatively affected by the update.

 Now that everything is defined, we can train the model. Define the input and the associated target:

-.. code-block:: python
+.. testcode::

    input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
    target1 = torch.randn(16)  # First batch of 16 targets
@@ -51,7 +51,7 @@ Here, we generate fake inputs and labels for the sake of the example.

 We can now compute the losses associated with each target.

-.. code-block:: python
+.. testcode::

    loss_fn = MSELoss()
    output = model(input)
@@ -62,7 +62,7 @@ The last steps are similar to gradient descent-based optimization, but using the

 Perform the Jacobian descent backward pass:

-.. code-block:: python
+.. testcode::

    autojac.backward([loss1, loss2])
    jac_to_grad(model.parameters(), aggregator)
@@ -73,14 +73,14 @@ field of the parameters. It also deletes the ``.jac`` fields to save some memory.

 Update each parameter based on its ``.grad`` field, using the ``optimizer``:

-.. code-block:: python
+.. testcode::

    optimizer.step()

 The model's parameters have been updated!

 As usual, you should now reset the ``.grad`` field of each model parameter:

-.. code-block:: python
+.. testcode::

    optimizer.zero_grad()
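
For readers skimming the diff, here is a minimal sketch that assembles the snippets above into a single training step. The import paths for ``SGD``, ``UPGrad``, ``autojac`` and ``jac_to_grad``, the definition of ``target2``, and the pairing of model outputs with targets are not visible in these hunks, so those parts are assumptions made only for illustration:

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD                # assumed import path
    from torchjd import autojac                # assumed import path
    from torchjd.aggregation import UPGrad     # assumed import path
    from torchjd.autojac import jac_to_grad    # assumed import path

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    optimizer = SGD(model.parameters(), lr=0.1)
    aggregator = UPGrad()

    input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
    target1 = torch.randn(16)    # First batch of 16 targets
    target2 = torch.randn(16)    # Second batch of 16 targets (assumed, mirroring target1)

    loss_fn = MSELoss()
    output = model(input)
    loss1 = loss_fn(output[:, 0], target1)  # Assumed pairing of output columns with targets
    loss2 = loss_fn(output[:, 1], target2)

    autojac.backward([loss1, loss2])             # Fills the .jac field of each parameter
    jac_to_grad(model.parameters(), aggregator)  # Aggregates .jac into .grad, then frees .jac
    optimizer.step()                             # Updates the parameters from their .grad fields
    optimizer.zero_grad()                        # Resets the .grad fields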