Backpropagation Through Time with looping statements in `forward()`

Hi community, lately I have been having trouble understanding the relationship between looping statements inside the forward() function and Backpropagation Through Time (BPTT).

In brief, I would like to create a custom recurrent model, similar to RNNs/LSTMs, in which forward() takes sequential input data of shape (sequence length, batch size, input dimension) and processes it by looping over the time steps, something like:

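```python
import torch

def forward(seq_data, hidden_states):
    # seq_data: (seq_len, batch, input_dim); iterating walks the time dimension
    for step in seq_data:
        # (some simple operations ...)
        out = torch.sigmoid(step) + hidden_states  # use the current step, not the whole sequence
        hidden_states = torch.tanh(out)
    return out, hidden_states
```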

Now, does PyTorch automatically detect this sequential processing in forward() and perform BPTT on its own when I call loss.backward() (i.e. without writing a custom backward())?
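One way to check this empirically is the minimal sketch below. It assumes a hypothetical `LoopRNN` module with two `nn.Linear` projections, which are illustrative and not part of the question. The loss depends only on the final hidden state, yet after `loss.backward()` the first time step of the input has a nonzero gradient, which means gradients flowed back through every iteration of the Python loop.

```python
import torch
import torch.nn as nn


class LoopRNN(nn.Module):
    """Hypothetical recurrent module: a plain Python loop over time steps in forward()."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.in_proj = nn.Linear(input_dim, hidden_dim)  # x_t -> hidden space
        self.h_proj = nn.Linear(hidden_dim, hidden_dim)  # h_{t-1} -> hidden space

    def forward(self, seq_data, hidden):
        # seq_data: (seq_len, batch, input_dim); iterating walks the time dimension
        for x_t in seq_data:
            hidden = torch.tanh(self.in_proj(x_t) + self.h_proj(hidden))
        return hidden


torch.manual_seed(0)
seq_len, batch, input_dim, hidden_dim = 5, 2, 3, 4
model = LoopRNN(input_dim, hidden_dim)
seq_data = torch.randn(seq_len, batch, input_dim, requires_grad=True)
h0 = torch.zeros(batch, hidden_dim)

h_last = model(seq_data, h0)
loss = h_last.pow(2).sum()  # loss depends only on the final hidden state
loss.backward()             # no custom backward(): autograd walks the recorded, unrolled graph

# Step 0 only influences the loss through all later steps, so a nonzero
# gradient here shows the error was propagated back through time.
print(seq_data.grad[0].abs().sum().item() > 0)
```

Because `h_proj` is applied once per iteration, its weight gradient is the sum of contributions from all time steps, which is exactly what BPTT computes.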

Any insights and corrections are much appreciated.

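No, nothing is “detected”: autograd simply records every operation executed in forward(), including each loop iteration, so calling loss.backward() on that unrolled graph is BPTT. This works for implementations without in-place operations on tensors that autograd needs for the backward pass; such implementations usually accumulate per-step outputs in a Python list first and stack them at the end. There was a blog post on this site about how the JIT speeds up RNN unrolling, but RNNs written in C++ are faster still. You may check SRU as an example that has both Python and C++ implementations.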
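To illustrate the in-place caveat, here is a minimal sketch (the function name `run_steps` is hypothetical). The first part collects each step's output in a Python list and calls `torch.stack` once, so backward works. The second part modifies the output of `torch.tanh` in place; since tanh saves its output for the backward pass, autograd detects the modification and `backward()` raises a `RuntimeError`.

```python
import torch


def run_steps(seq_data, hidden):
    # Collect each step's output in a Python list and stack once at the end,
    # so no tensor that autograd saved for backward is ever overwritten.
    outputs = []
    for x_t in seq_data:
        hidden = torch.tanh(torch.sigmoid(x_t) + hidden)
        outputs.append(hidden)
    return torch.stack(outputs), hidden  # (seq_len, batch, dim), (batch, dim)


seq_data = torch.randn(4, 2, 3, requires_grad=True)
h0 = torch.zeros(2, 3)
outs, h_last = run_steps(seq_data, h0)
outs.sum().backward()  # works: the recorded graph is intact
print(seq_data.grad.shape)

# Counter-example: tanh keeps its output for the backward pass, so changing
# that output in place invalidates the graph and backward() raises.
x = torch.randn(3, requires_grad=True)
h = torch.tanh(x)
h.add_(1.0)
try:
    h.sum().backward()
    inplace_failed = False
except RuntimeError:
    inplace_failed = True
print(inplace_failed)
```

The list-then-stack pattern costs one extra copy at the end, but it keeps every intermediate tensor untouched, which is what autograd needs to replay the loop backwards.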