Backpropagation Through Time with looping statements in `forward()`

Hi community, lately I've been having trouble understanding the relationship between looping statements inside the `forward()` function and Backpropagation Through Time (BPTT).

In brief, I would like to create a custom recurrent model, similar to an RNN/LSTM, in which `forward()` receives sequential input data (of shape `(sequence length, batch size, input dimension)`) and processes it by looping over the sequence dimension, something like this:

def forward(seq_data, hidden_states):
    for step in seq_data:  # iterate over the time dimension
        # (some simple operations ...)
        out = torch.sigmoid(step) + hidden_states
        hidden_states = torch.tanh(out)
        # .....
    return out, hidden_states
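
For concreteness, here is a fuller, self-contained version of the idea (a minimal sketch; the module and layer names are placeholders I made up, not anything from an existing codebase):

import torch
import torch.nn as nn

class MyRecurrent(nn.Module):  # hypothetical name, just fleshing out the pseudocode above
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.inp = nn.Linear(input_dim, hidden_dim)  # per-step input transform

    def forward(self, seq_data, hidden_states):
        # seq_data: (sequence length, batch size, input dimension)
        for step in seq_data:  # loop over the time dimension
            out = torch.sigmoid(self.inp(step)) + hidden_states
            hidden_states = torch.tanh(out)
        return out, hidden_states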

Now, does PyTorch automatically handle this sequential processing in `forward()` and perform BPTT on its own when calling `loss.backward()` (i.e. without my writing a custom `backward()`)?

Any insights and corrections are much appreciated 🙂


No, nothing is “detected”, but backprop works for loop-based implementations as long as they avoid in-place operations (such implementations usually accumulate per-step outputs in a Python list first). There was a blog post on this site about how the JIT speeds up RNN unrolling, but C++-coded RNNs are still faster. You may check SRU as an example that has both Python and C++ implementations.
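
To illustrate, here is a minimal sketch (toy shapes, no in-place writes) showing that autograd records each loop iteration, so `loss.backward()` unrolls through every time step:

import torch

seq_len, batch, dim = 5, 3, 4
x = torch.randn(seq_len, batch, dim, requires_grad=True)
h = torch.zeros(batch, dim)

outs = []                       # accumulate per-step outputs in a python list
for step in x:                  # autograd records every iteration of the loop
    out = torch.sigmoid(step) + h
    h = torch.tanh(out)
    outs.append(out)

loss = torch.stack(outs).sum()  # stack the list instead of writing into a tensor in-place
loss.backward()                 # BPTT through all time steps, no custom backward() needed
print(x.grad.shape)             # gradients reach every step: torch.Size([5, 3, 4])

Writing results into a preallocated tensor in-place is the pattern to avoid; the list-plus-`torch.stack` approach above sidesteps it.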