Hi community, lately I've been having trouble understanding the relationship between looping statements inside
the forward() function and Backpropagation Through Time (BPTT).
In brief, I would like to create a custom recurrent model, similar to RNNs/LSTMs, in which
forward() receives sequential input data (of shape (sequence length, batch size, input dimension)) and processes it by looping over the sequence dimension, something like:
```python
def forward(seq_data, hidden_states):
    for sequence in seq_data:
        # (some simple operations ...)
        out = torch.sigmoid(sequence) + hidden_states
        hidden_states = torch.tanh(out)
        # ...
    return out, hidden_states
```
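To make the question concrete, here is a minimal runnable sketch of the kind of model I mean (the module name, the linear layer, and the shapes are just illustrative assumptions, not my actual code):

```python
import torch
import torch.nn as nn

# Hypothetical minimal recurrent module: the loop in forward() processes
# one time step at a time, reusing the hidden state across steps.
class SimpleRecurrent(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)  # illustrative "simple operation"

    def forward(self, seq_data, hidden_states):
        # seq_data shape: (sequence length, batch size, input dimension)
        out = None
        for step in seq_data:  # iterate over the time dimension
            out = torch.sigmoid(self.linear(step)) + hidden_states
            hidden_states = torch.tanh(out)
        return out, hidden_states

seq_len, batch, dim = 5, 3, 4
model = SimpleRecurrent(dim)
x = torch.randn(seq_len, batch, dim)
h = torch.zeros(batch, dim)

out, h = model(x, h)
loss = out.sum()
loss.backward()  # autograd replays the unrolled loop, i.e. BPTT,
                 # with no hand-written backward()
```

After backward(), `model.linear.weight.grad` is populated with gradients accumulated through every time step of the loop.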
Now, does PyTorch automatically track this sequential processing in forward() and perform BPTT on its own when I call
loss.backward() (i.e. without my writing a custom backward())?
Any insights and corrections are much appreciated!