Implementing batched training without batching data

Hi. I have a beginner question that I can't seem to find a definitive answer to anywhere.

From what I understand, doing batched training has 2 benefits:

  • Batched operations can be optimized better on the GPU, leading to performance gains
  • The optimizer only calls .step() after seeing multiple data samples, which makes the updates less sensitive to any individual sample.

My question is as follows: if we ignore the performance benefits, are these two training loops equivalent?

Proper batching:

model = SomeModel()
optimizer = SomeOptimizer(model.parameters(), ...)
dataloader = DataLoader(dataset, batch_size=10)

for batch_data in dataloader:
    loss = model(batch_data)   # forward pass over the whole batch
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Manual batching?:

model = SomeModel()
optimizer = SomeOptimizer(model.parameters(), ...)
dataloader = DataLoader(dataset, batch_size=10)

for batch_data in dataloader:
    for i in range(10):
        loss = model(batch_data[i])   # forward pass on a single sample
        loss.backward()               # gradients accumulate across the inner loop

    optimizer.step()
    optimizer.zero_grad()

If they are not equivalent, what is the difference, and can it be remedied without passing batched data into the model?
For context: due to time pressure I'm not able to implement the model with batched input, but I would still like to simulate the effects of batched learning (except for the performance benefits).

Thanks in advance!

Hi @ldkuba. Did you find any difference? I have the same doubts. Thanks.

It depends on the loss. For example, MSELoss defaults to reduction="mean", so in the batched loop the loss is computed over all samples in the batch and then averaged. In the manual loop you call backward() on each per-sample loss, so the gradients add up, which is equivalent to summing the losses instead of averaging them (for batch_size=10 the accumulated gradient is 10× larger). Apart from that scaling, what you are doing is gradient accumulation, and the two loops should be equivalent. To match the mean reduction, divide each per-sample loss by the batch size before calling backward().
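
A minimal sketch of that remedy, assuming a standard criterion like nn.MSELoss; the toy model, dataset, and hyperparameters here are placeholders, not the original poster's setup:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the model/data in the question
model = nn.Linear(4, 1)
criterion = nn.MSELoss()                  # reduction="mean" by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size)

for batch_x, batch_y in dataloader:
    optimizer.zero_grad()
    for i in range(batch_x.size(0)):
        out = model(batch_x[i].unsqueeze(0))             # add a batch dim of size 1
        loss = criterion(out, batch_y[i].unsqueeze(0))
        # Dividing by the batch size makes the accumulated gradient match
        # the gradient of a mean-reduced loss over the full batch.
        (loss / batch_size).backward()
    optimizer.step()

With that scaling, the parameter update after optimizer.step() should match the batched loop up to floating-point differences.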

Also, some PyTorch modules accept both batched (B, C, ...) and unbatched (C, ...) input, but others, especially custom ones, only work on batched data, so you might get shape errors or silently weird results when passing unbatched data. In that case you can add a batch dimension of size 1, e.g. loss = model(batch_data[i].unsqueeze(0)).
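
A quick shape check of what unsqueeze(0) does (nn.Conv2d is just an illustrative module, not necessarily the one in your model):

import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

sample = torch.randn(3, 32, 32)     # unbatched: (C, H, W)
batched = sample.unsqueeze(0)       # (1, C, H, W): a batch of one

print(conv(batched).shape)          # torch.Size([1, 8, 32, 32])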