Commit 8212f57

t-vi authored and soumith committed
improve RNN docs (fixes #3587) (#7669)
1 parent f7bc700

1 file changed: +28 −4 lines changed

torch/nn/modules/rnn.py

Lines changed: 28 additions & 4 deletions
@@ -278,13 +278,21 @@ class RNN(RNNBase):
             Defaults to zero if not provided.
 
     Outputs: output, h_n
-        - **output** of shape `(seq_len, batch, hidden_size * num_directions)`: tensor
+        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
           containing the output features (`h_k`) from the last layer of the RNN,
           for each `k`.  If a :class:`torch.nn.utils.rnn.PackedSequence` has
           been given as the input, the output will also be a packed sequence.
+
+          For the unpacked case, the directions can be separated
+          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+          with forward and backward being direction `0` and `1` respectively.
+          Similarly, the directions can be separated in the packed case.
         - **h_n** (num_layers * num_directions, batch, hidden_size): tensor
           containing the hidden state for `k = seq_len`.
+
+          Like *output*, the layers can be separated using
+          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
+
     Attributes:
         weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
             of shape `(hidden_size * input_size)` for `k = 0`. Otherwise, the shape is
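The view trick documented in the hunk above can be checked with a short script (a sketch; the module sizes here are arbitrary). It builds a bidirectional `nn.RNN` and separates the flattened `output` and `h_n` tensors exactly as the added docstring text describes.

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen for illustration only
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
num_directions = 2  # because bidirectional=True below

rnn = nn.RNN(input_size, hidden_size, num_layers, bidirectional=True)
output, h_n = rnn(torch.randn(seq_len, batch, input_size))

# output is (seq_len, batch, num_directions * hidden_size); its last
# dimension factors as (num_directions, hidden_size), so the view below
# separates forward (index 0) and backward (index 1) features.
dirs = output.view(seq_len, batch, num_directions, hidden_size)
forward_feats = dirs[..., 0, :]
backward_feats = dirs[..., 1, :]

# h_n is (num_layers * num_directions, batch, hidden_size); its leading
# dimension factors as (num_layers, num_directions).
layers = h_n.view(num_layers, num_directions, batch, hidden_size)

print(output.shape, forward_feats.shape, layers.shape)
```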
@@ -377,12 +385,20 @@ class LSTM(RNNBase):
 
 
     Outputs: output, (h_n, c_n)
-        - **output** of shape `(seq_len, batch, hidden_size * num_directions)`: tensor
+        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
           containing the output features `(h_t)` from the last layer of the LSTM,
           for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
           given as the input, the output will also be a packed sequence.
+
+          For the unpacked case, the directions can be separated
+          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+          with forward and backward being direction `0` and `1` respectively.
+          Similarly, the directions can be separated in the packed case.
         - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
-          containing the hidden state for `t = seq_len`
+          containing the hidden state for `t = seq_len`.
+
+          Like *output*, the layers can be separated using
+          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
         - **c_n** (num_layers * num_directions, batch, hidden_size): tensor
           containing the cell state for `t = seq_len`
 
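The LSTM hunk adds the same separation advice, extended to *c_n*. A small sketch (sizes are arbitrary) applies the views and also checks a consistency property: the last forward time step of `output` matches the top-layer forward slice of `h_n`.

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen for illustration only
seq_len, batch, input_size, hidden_size, num_layers = 7, 2, 3, 5, 2
num_directions = 2  # because bidirectional=True below

lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
output, (h_n, c_n) = lstm(torch.randn(seq_len, batch, input_size))

# Separate the directions in output, and the layers in h_n / c_n,
# as the updated docstring suggests.
out_by_dir = output.view(seq_len, batch, num_directions, hidden_size)
h_by_layer = h_n.view(num_layers, num_directions, batch, hidden_size)
c_by_layer = c_n.view(num_layers, num_directions, batch, hidden_size)

# Sanity check: output at the final time step, forward direction, is the
# final forward hidden state of the top layer.
assert torch.allclose(out_by_dir[-1, :, 0], h_by_layer[-1, 0])
```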
@@ -457,13 +473,21 @@ class GRU(RNNBase):
             Defaults to zero if not provided.
 
     Outputs: output, h_n
-        - **output** of shape `(seq_len, batch, hidden_size * num_directions)`: tensor
+        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
           containing the output features h_t from the last layer of the GRU,
           for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
           given as the input, the output will also be a packed sequence.
+          For the unpacked case, the directions can be separated
+          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+          with forward and backward being direction `0` and `1` respectively.
+
+          Similarly, the directions can be separated in the packed case.
         - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
           containing the hidden state for `t = seq_len`
 
+          Like *output*, the layers can be separated using
+          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
+
     Attributes:
         weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
             (W_ir|W_iz|W_in), of shape `(3*hidden_size x input_size)`

0 commit comments