Hi,
My model contains many operations that do not support batching (matrix decomposition, diag, etc.), and my recurrent model takes variable-length inputs. So I decided to keep things simple: loop over the elements of a batch and accumulate the loss from each iteration. Pseudocode below:
batch_loss = torch.tensor(0.0)  # no requires_grad: an in-place += on a leaf that requires grad raises a RuntimeError
for element in batch:
    output = recurrent_model(element)
    loss = loss_fn(output)
    batch_loss = batch_loss + loss  # out-of-place add so autograd can track it
# normalize batch loss
norm_loss = batch_loss / len(batch)
norm_loss.backward()
Will this make the computational graph repeat itself as many times as there are elements in the batch?
Is there a better way to do this?
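In case it helps frame the question, here is the alternative I was considering: calling backward() inside the loop so each element's graph is freed immediately instead of being kept alive until the end. This is a sketch with hypothetical stand-ins for recurrent_model, loss_fn, and batch; gradients accumulate in .grad across the per-element backward() calls, and pre-scaling each loss by the batch size makes the result match the mean loss.

```python
import torch

# Hypothetical stand-ins for the model, loss, and batch described above.
recurrent_model = torch.nn.Linear(4, 1)
loss_fn = lambda output: output.pow(2).mean()
batch = [torch.randn(4) for _ in range(3)]  # variable-length in my real case

optimizer = torch.optim.SGD(recurrent_model.parameters(), lr=0.1)
optimizer.zero_grad()
for element in batch:
    output = recurrent_model(element)
    # Pre-scale so the accumulated gradients equal those of the mean loss.
    loss = loss_fn(output) / len(batch)
    loss.backward()  # frees this element's graph right away
optimizer.step()
```

As far as I understand, this gives the same gradients as summing the losses and calling backward() once, but with lower peak memory.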
Thanks