Hi community, lately I've been having trouble understanding the relationship between looping statements inside the forward()
function and Backpropagation Through Time (BPTT).
In brief, I would like to create a custom recurrent model, similar to an RNN/LSTM, whose forward()
takes sequential input data (of shape (sequence length, batch size, input dimension)) and processes it by looping over the time steps, something like:
def forward(seq_data, hidden_states):
    for sequence in seq_data:  # iterate over the time dimension
        # (some simple operations ...)
        out = torch.sigmoid(sequence) + hidden_states
        hidden_states = torch.tanh(out)
        # .....
    return out, hidden_states
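For concreteness, here is a fuller, self-contained sketch of the kind of model I have in mind; the Linear layer, names, and dimensions are just placeholders I made up for illustration:

```python
import torch
import torch.nn as nn

class SimpleRecurrent(nn.Module):
    """Toy recurrent cell along the lines of the forward() above
    (hypothetical example, not my actual model)."""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, hidden_dim)

    def forward(self, seq_data, hidden_states):
        # seq_data: (sequence length, batch size, input dimension)
        out = None
        for step in seq_data:  # loop over time steps
            out = torch.sigmoid(self.proj(step)) + hidden_states
            hidden_states = torch.tanh(out)
        return out, hidden_states

seq_len, batch, in_dim, hid_dim = 5, 3, 4, 4
model = SimpleRecurrent(in_dim, hid_dim)
x = torch.randn(seq_len, batch, in_dim)
h0 = torch.zeros(batch, hid_dim)

out, h = model(x, h0)
loss = out.sum()
loss.backward()  # does autograd unroll the loop and backprop through every step?
print(model.proj.weight.grad is not None)
```

My understanding is that the weight gradient here should reflect contributions from all time steps, which is what I want to confirm below.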
Now, does PyTorch's autograd automatically record this sequential processing in forward() and perform BPTT on its own when I call loss.backward()
(i.e. without my writing a custom backward())?
Any insights and corrections are much appreciated!