I am trying to do batch learning on a large dataset that will not fit on the GPU. I am not sure where to clear the gradients and compute the loss. Is this the correct way to use a
DataLoader and move the data to the GPU in pieces for batch learning?
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)

for epoch in range(n_epochs):
    for loc_X_train, loc_y_train in train_dataloader:
        # move data to the device
        loc_X_train, loc_y_train = loc_X_train.cuda(), loc_y_train.cuda()
        # clear any calculated gradients
        # forward pass, compute outputs
        outputs = model.forward(loc_X_train)
        # compute loss
        loss = loss_function(outputs, loc_y_train)
        # backward pass, compute gradients
        # update learnable parameters
The DataLoader loop looks alright.
You shouldn’t call model.forward, but the model directly via outputs = model(loc_X_train) so that registered hooks will be properly called.
Also, usually you would update the parameters after the backward call in each iteration, not once per epoch, so you might want to call optimizer.step() inside the inner DataLoader loop.
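In code, that standard per-iteration update could look like this minimal sketch (assuming an optimizer, e.g. torch.optim.SGD(model.parameters(), lr=lr), was created beforehand):

for epoch in range(n_epochs):
    for loc_X_train, loc_y_train in train_dataloader:
        loc_X_train, loc_y_train = loc_X_train.cuda(), loc_y_train.cuda()
        optimizer.zero_grad()             # clear gradients from the previous iteration
        outputs = model(loc_X_train)      # call the model directly so hooks fire
        loss = loss_function(outputs, loc_y_train)
        loss.backward()                   # compute gradients
        optimizer.step()                  # update parameters once per batch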
In batch learning, the weights are only updated once per epoch, so the optimizer.step() call needs to be at the end of the outer loop. I’m starting to think that the optimizer.zero_grad() call should be just before the inner loop.
Yes, in that case you shouldn’t zero out the gradients in each iteration, and note that each backward() call would accumulate the gradients. In batch training the gradients are usually calculated using the mean, so you might also need to scale the gradients before applying the update with optimizer.step().
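A minimal sketch of that once-per-epoch variant, assuming loss_function uses the default "mean" reduction and all batches have the same size (so averaging the accumulated per-batch gradients recovers the full-batch gradient):

num_batches = len(train_dataloader)

for epoch in range(n_epochs):
    optimizer.zero_grad()                 # clear gradients once per epoch
    for loc_X_train, loc_y_train in train_dataloader:
        loc_X_train, loc_y_train = loc_X_train.cuda(), loc_y_train.cuda()
        outputs = model(loc_X_train)
        loss = loss_function(outputs, loc_y_train)
        loss.backward()                   # gradients accumulate across batches
    # each backward() added a per-batch mean gradient, so average them
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= num_batches
    optimizer.step()                      # single update per epoch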
What’s the best way to do this? If I put the following between the loss calculation and the backward step, would it take care of it?

if loss_function.reduction != "sum":
    loss *= len(outputs)
No, this would increase the loss even further: multiplying a mean-reduced loss by the batch size turns it back into a sum, so the accumulated gradients would grow rather than average out.
Often you are scaling the loss via:

loss = loss / accumulation_steps

in each iteration.
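Applied to the loop from this thread, that could look like the following sketch; here accumulation_steps is assumed to be the number of batches per epoch, since the goal is one update per epoch:

accumulation_steps = len(train_dataloader)

for epoch in range(n_epochs):
    optimizer.zero_grad()
    for loc_X_train, loc_y_train in train_dataloader:
        loc_X_train, loc_y_train = loc_X_train.cuda(), loc_y_train.cuda()
        outputs = model(loc_X_train)
        loss = loss_function(outputs, loc_y_train)
        loss = loss / accumulation_steps  # scale before backward so the summed gradients average
        loss.backward()
    optimizer.step()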