My model contains many operations that do not support batched inputs (matrix decomposition, diag, etc.), and my recurrent model takes variable-length inputs. So I decided to simplify things and just loop over the elements in a batch, accumulating the loss from each iteration. Some pseudocode below:
```python
# accumulate out-of-place: += on a leaf tensor created with
# requires_grad=True raises a RuntimeError
batch_loss = torch.tensor(0.0)
for element in batch:
    output = recurrent_model(element)
    loss = loss_fn(output)
    batch_loss = batch_loss + loss

# normalize batch loss
norm_loss = batch_loss / len(batch)
norm_loss.backward()
```
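For concreteness, here is a minimal runnable version of that loop. The GRU, linear readout, MSE loss, and random variable-length inputs are just stand-ins for my actual model and loss:

```python
import torch
import torch.nn as nn

# stand-ins for my actual recurrent model and loss
recurrent_model = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
readout = nn.Linear(16, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(
    list(recurrent_model.parameters()) + list(readout.parameters())
)

# variable-length sequences, so each element is processed individually
batch = [torch.randn(1, torch.randint(5, 20, (1,)).item(), 8) for _ in range(4)]
targets = [torch.randn(1, 1) for _ in range(4)]

optimizer.zero_grad()
batch_loss = 0.0
for element, target in zip(batch, targets):
    _, h_n = recurrent_model(element)   # h_n: (num_layers, 1, 16)
    output = readout(h_n[-1])           # prediction from the final hidden state
    batch_loss = batch_loss + loss_fn(output, target)

norm_loss = batch_loss / len(batch)
norm_loss.backward()
optimizer.step()
```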
Will this make the computational graph repeat itself as many times as there are elements in the batch?
Is there a better way to do this?