Is it possible to calculate the loss over slices of samples within a mini-batch and aggregate that loss?

Hi. I’m currently using a loss function that causes resource problems for me (specifically, GPU memory).

The number of samples within a mini-batch is often in the hundreds, and each mini-batch has a different number of samples. Calculating the loss over all of them in the “normal” way causes memory issues, so one approach I thought of is to further divide the data into slices, calculate the loss over each slice, and later aggregate those losses for the backward pass. I’m not completely sure this would solve the fundamental issue, but it seemed worth a try.

Right now, what I have looks something like this:

num_samples = embeddings.shape[0]
num_batches = num_samples // batch_size

# Include remaining samples.
if num_samples % batch_size != 0:
    num_batches += 1

losses = []
current_idx = 0
for _ in range(num_batches):
    embeddings_batch = embeddings[current_idx:current_idx + batch_size]
    labels_batch = labels[current_idx:current_idx + batch_size]

    loss = self.loss_function(embeddings_batch, labels_batch)
    losses.append(loss)

    current_idx += batch_size
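
For the aggregation afterwards, the rough plan (not implemented yet) is something like:

# Average the per-slice losses and backpropagate through the combined value.
total_loss = torch.stack(losses).mean()
total_loss.backward()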

When I check the values in losses after running the loop, they look like:

tensor(7.2189, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

I’m not sure why all of the values after the first one are zero; I’d expect them to be roughly similar to the first. Does anyone know why this might be happening, and how I could fix it?

Thanks.

The fact that the grad_fn differs between the first loss and the rest is a hint that something unexpected is going on. Can you check the shapes of embeddings_batch and labels_batch (and maybe a few of their values too) right before they are passed to the loss function?
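
For example, adding a couple of prints inside the loop (just for debugging) would show whether the later slices are empty or contain unexpected labels:

print(embeddings_batch.shape, labels_batch.shape)
print(labels_batch[:5])  # spot-check a few label values in each slice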

However, a more fundamental issue is that the current setup will still keep all of the intermediate tensors from the forward pass in memory until the backward pass. Unless your loss function has some esoteric scaling properties (e.g., a nonlinear memory increase with batch size), I’m not sure this will actually save GPU memory. There are fancy methods that save GPU memory at the cost of a little extra recomputation (e.g., [2006.09616] Dynamic Tensor Rematerialization), but I don’t think that is easy to do out of the box in upstream PyTorch.
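
One thing that is available upstream in that spirit is activation checkpointing via torch.utils.checkpoint, which drops the intermediates of a wrapped function and recomputes them during the backward pass. A minimal sketch of how it could be combined with your slicing loop (assuming loss_function is an ordinary differentiable callable of a slice) would be:

import torch
from torch.utils.checkpoint import checkpoint

losses = []
for start in range(0, num_samples, batch_size):
    embeddings_batch = embeddings[start:start + batch_size]
    labels_batch = labels[start:start + batch_size]

    # Intermediates inside loss_function are not stored here; they are
    # recomputed when backward() reaches this slice.
    # use_reentrant=False needs a reasonably recent PyTorch version.
    loss = checkpoint(self.loss_function, embeddings_batch, labels_batch,
                      use_reentrant=False)
    losses.append(loss)

torch.stack(losses).mean().backward()

Note that this only saves memory allocated inside the loss computation itself; the activations of the model that produced embeddings are kept either way.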