Computational graph: Looping over elements in a batch


My model contains many operations that do not support batching (matrix decomposition, diag, etc.), and my recurrent model takes variable-length inputs. So I decided to simplify things and just loop over the elements in a batch, accumulating the loss from each iteration. Some pseudocode below.

# Start from a plain Python float: in-place += on a leaf tensor
# created with requires_grad=True would raise a RuntimeError.
batch_loss = 0.0
for element in batch:
    output = recurrent_model(element)
    loss = loss_fn(output)
    batch_loss = batch_loss + loss  # out-of-place add keeps the graph intact

# normalize batch loss
norm_loss = batch_loss / len(batch)

Will this make the computational graph repeat itself as many times as there are elements in the batch?
Is there a better way to do this?
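To make the question concrete, here is a minimal runnable sketch of the loop, assuming a small hypothetical RNN, a linear head, and MSE loss (none of these come from the original post; they just stand in for `recurrent_model` and `loss_fn`). Collecting the per-element losses in a list and averaging with `torch.stack(...).mean()` is a common variant of the same idea: each iteration builds its own subgraph, and all of them stay alive until the single `backward()` call.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for recurrent_model and loss_fn.
model = nn.RNN(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)
loss_fn = nn.MSELoss()

# Variable-length "batch": each element is (seq_len, 1, input_size).
batch = [torch.randn(L, 1, 4) for L in (3, 5, 2)]
targets = [torch.randn(1, 1) for _ in batch]

losses = []
for element, target in zip(batch, targets):
    out, _ = model(element)   # out: (seq_len, 1, hidden)
    pred = head(out[-1])      # prediction from the last time step
    losses.append(loss_fn(pred, target))

# One graph per element, joined by the stack/mean; a single
# backward() then accumulates gradients from every element.
norm_loss = torch.stack(losses).mean()
norm_loss.backward()
```

So yes: the graph is replicated once per batch element, which is why memory grows with batch size until `backward()` frees the graphs.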


Read this

Use detach() if you do not want to keep the tensor in the graph anymore.
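A tiny sketch of what detach() does, for reference (the variable names are illustrative):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

# detach() returns a tensor that shares data with y but is cut
# out of the autograd graph, so gradients will not flow through z.
z = y.detach()

# Backpropagating through the original y still works as usual.
y.backward()
```

This is useful for things like logging running losses, where keeping the graph alive would only waste memory.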