Recurrent Neural Network (RNN)
==============================

When training a recurrent neural network for sequence modelling, we can easily obtain one loss per
element of the output sequence. If the gradients of these losses are likely to conflict, Jacobian
descent can be leveraged to enhance optimization.

.. code-block:: python
    :emphasize-lines: 5-6, 10, 17, 20


    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd import backward
    from torchjd.aggregation import UPGrad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.

    for input, target in zip(inputs, targets):
        output, _ = rnn(input)  # output is of shape [5, 3, 20].
        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

        optimizer.zero_grad()
        backward(losses, aggregator, parallel_chunk_size=1)
        optimizer.step()

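
To build some intuition for what ``backward`` does with these losses, the sketch below applies the
aggregator directly to a small, hand-written Jacobian (one row per loss, one column per parameter).
It assumes the callable interface of ``UPGrad`` on a 2-D matrix and is only an illustration of the
aggregation step, not part of the training loop above.

.. code-block:: python

    import torch

    from torchjd.aggregation import UPGrad

    aggregator = UPGrad()

    # Toy Jacobian: one row per loss, one column per parameter.
    jacobian = torch.tensor([
        [1.0, 2.0, -1.0],
        [-1.0, 0.5, 1.0],
    ])

    # The aggregator maps the Jacobian to a single update direction whose length
    # matches the number of columns (i.e. the number of parameters), aiming to
    # avoid conflict with each individual row.
    direction = aggregator(jacobian)  # direction has shape [3].
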

.. note::
    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
    ``torch.nn.RNN`` when running on CUDA (see `this issue
    <https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise setting
    ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
    check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.
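
If you want to perform that check programmatically, one possible sketch is to attempt a single
Jacobian descent step with full parallelization and fall back to ``parallel_chunk_size=1`` if it
fails. It reuses the ``rnn``, ``inputs``, ``targets``, ``aggregator`` and ``backward`` objects from
the example above; the exact exception raised by an incompatible ``torch.vmap`` setup may vary, and
``chunk_size`` is just a hypothetical name for the value you would then pass as
``parallel_chunk_size`` in the training loop.

.. code-block:: python

    # Rough probe (assumption: a vmap incompatibility surfaces as an exception here).
    try:
        output, _ = rnn(inputs[0])
        losses = ((output - targets[0]) ** 2).mean(dim=[1, 2])
        backward(losses, aggregator, parallel_chunk_size=None)  # Maximal parallelization.
        chunk_size = None  # torch.vmap seems to work on this setup.
    except Exception:
        chunk_size = 1  # Fall back to the safe, sequential behaviour.
    finally:
        rnn.zero_grad()  # Discard any gradients accumulated by the probe.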