I’m trying to train a model in PyTorch with an effective batch size of 8, but due to memory limitations I can fit a batch size of at most 4. I’ve looked all around and read a lot about accumulating gradients, and it seems like the solution to my problem.
However, I seem to have trouble implementing it. Every time I run the code I get: RuntimeError: Trying to backward through the graph a second time. I don’t understand why, since my code looks like all the other examples I’ve seen (unless I’m just missing something major).
One caveat is that the labels for my images are all different sizes, so I can’t send the output batch and the label batch into the loss function together; I have to iterate over them pairwise. This is what an epoch looks like (it’s been pared down for the sake of brevity):
# labels_batch contains labels of different sizes
for batch_idx, (inputs_batch, labels_batch) in enumerate(dataloader):
    outputs_batch = model(inputs_batch)
    # have to do this because labels can't be stacked into a tensor
    for output, label in zip(outputs_batch, labels_batch):
        output_scaled = interpolate(...)  # make output match label size
        loss = train_criterion(output_scaled, label) / (BATCH_SIZE * 2)
        loss.backward()
    if batch_idx % 2 == 1:
        optimizer.step()
        optimizer.zero_grad()
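For what it’s worth, the error seems reproducible in isolation. Below is a minimal sketch of the loop above, assuming a toy two-layer network and nn.MSELoss as stand-ins for my real model and train_criterion (both are hypothetical placeholders, not my actual setup). The interpolate step is left out because the toy outputs and labels already have matching sizes.

```python
import torch
from torch import nn

torch.manual_seed(0)

BATCH_SIZE = 4
# Hypothetical stand-ins for the real model and train_criterion.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
criterion = nn.MSELoss()

inputs_batch = torch.randn(BATCH_SIZE, 3)
labels_batch = [torch.randn(2) for _ in range(BATCH_SIZE)]

# One forward pass for the whole batch; every per-sample loss below
# is derived from this single outputs_batch tensor.
outputs_batch = model(inputs_batch)

error_message = None
failed_at = None
try:
    for sample_idx, (output, label) in enumerate(zip(outputs_batch, labels_batch)):
        loss = criterion(output, label) / (BATCH_SIZE * 2)
        loss.backward()
except RuntimeError as err:
    # Record which inner-loop iteration raised and what the message was.
    error_message = str(err)
    failed_at = sample_idx

print(failed_at)
print(error_message)
```

In this sketch the first call to loss.backward() goes through and it is the second one, on the next sample of the same batch, that raises.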
Is there something I’m missing? I could do the following and accumulate the losses:
# labels_batch contains labels of different sizes
for batch_idx, (inputs_batch, labels_batch) in enumerate(dataloader):
    outputs_batch = model(inputs_batch)
    # CHANGE: we're gonna accumulate losses manually
    batch_loss = 0
    # have to do this because labels can't be stacked into a tensor
    for output, label in zip(outputs_batch, labels_batch):
        output_scaled = interpolate(...)  # make output match label size
        loss = train_criterion(output_scaled, label) / (BATCH_SIZE * 2)
        batch_loss += loss  # CHANGE: accumulate!
    # CHANGE: do backprop outside for loop
    batch_loss.backward()
    if batch_idx % 2 == 1:
        optimizer.step()
        optimizer.zero_grad()
and this works, but then my question is: why can’t I call loss.backward() in the inner loop? I’m creating a new loss object each time… If I set the batch size to 1 it would essentially be doing the same thing.
But all that aside, is there a numerical difference between calling loss.backward() on every iteration of the inner loop vs. accumulating the losses into a batch loss?
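To make that last question concrete, here is a rough sketch that compares the two variants on toy data. It assumes the same hypothetical stand-ins as before (a small nn.Sequential model and nn.MSELoss in place of my real model and train_criterion), and it passes retain_graph=True to the per-sample backward() so that the graph shared by the whole batch is kept alive between calls. That flag is my assumption about how to get the inner-loop version to run at all, not something taken from my training code.

```python
import torch
from torch import nn

BATCH_SIZE = 4


def make_model():
    # Hypothetical toy model; reseeding gives both variants identical weights.
    torch.manual_seed(0)
    return nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))


criterion = nn.MSELoss()  # hypothetical stand-in for train_criterion

torch.manual_seed(1)
inputs_batch = torch.randn(BATCH_SIZE, 3)
labels_batch = [torch.randn(2) for _ in range(BATCH_SIZE)]

# Variant A: backward() per sample, retaining the shared graph each time.
model_a = make_model()
outputs_a = model_a(inputs_batch)
for output, label in zip(outputs_a, labels_batch):
    loss = criterion(output, label) / (BATCH_SIZE * 2)
    loss.backward(retain_graph=True)

# Variant B: sum the per-sample losses, then a single backward().
model_b = make_model()
outputs_b = model_b(inputs_batch)
batch_loss = sum(
    criterion(output, label) / (BATCH_SIZE * 2)
    for output, label in zip(outputs_b, labels_batch)
)
batch_loss.backward()

# Compare the accumulated .grad of every parameter across the two variants.
grads_match = all(
    torch.allclose(p_a.grad, p_b.grad, atol=1e-6)
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters())
)
print(grads_match)
```

Since differentiation is linear, the gradient of the summed loss equals the sum of the per-sample gradients, so I would expect the two variants to agree up to floating-point rounding; any remaining difference should come only from the order in which the terms are added.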