Memory usage management (how to deal with small batches)

Hi, I’m trying to train a convolutional network on 3D data and am having trouble with memory usage.

I’ve searched the topic and found that one solution is to train in batches. However, I have to use really small batches for the data to fit in memory: batches of 5 on the not-very-new computer I usually work on, and 32 on a better-suited machine.

Given this, I have two questions:

  1. Is it realistic to train a neural network with very small batches that are poorly representative of the data? I have to classify 20 classes, and I expect significant variance within each class. If so, do I need to adapt my training process in some way? For example, I’m guessing I should probably use a small learning rate; would that be correct?

  2. My current training loop looks as follows:
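    for epoch in range(num_epoch):
        for i, (inputs, labels) in enumerate(dataloader):
            y_pred = model(inputs)       # forward pass on one mini-batch
            l = loss(y_pred, labels)
            optimizer.zero_grad()        # clear gradients from the previous step
            l.backward()                 # compute gradients for this mini-batch
            optimizer.step()             # update the weights once per mini-batch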
    Here I take an optimizer step for every mini-batch of 5 (or 32) samples. Could I instead do something like this?
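    batch_size = 5              # batch size used by the DataLoader, small enough to fit in memory
    desired_batch_size = 5000   # number of samples I would like per optimizer step
    batches_per_step = desired_batch_size // batch_size
    for epoch in range(num_epoch):
        outs, labels_ = [], []  # predictions and labels collected since the last step
        for i, (inputs, labels) in enumerate(dataloader):
            outs.append(model(inputs))
            labels_.append(labels)
            # once enough samples have been collected, compute one loss over all of them
            if (i + 1) % batches_per_step == 0:
                l = loss(torch.cat(outs, 0), torch.cat(labels_, 0))
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                outs, labels_ = [], []  # start collecting the next group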
    Here I’m trying to take one step only after enough data points have been processed to be reasonably representative of the data. Or would this make it impossible to compute the required partial derivatives?
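For comparison, a commonly used way to get a large effective batch without keeping every mini-batch's outputs (and their computation graphs) alive until the step is gradient accumulation: call backward() on each small batch, which makes PyTorch add the new gradients to the existing .grad values, and only call optimizer.step() every few batches. The sketch below assumes a loss with reduction="mean" and equal-sized mini-batches; the names train_with_accumulation and accumulation_steps are hypothetical, not part of the PyTorch API, and the toy 3D model and random data are made up purely for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_with_accumulation(model, dataloader, loss_fn, optimizer,
                            num_epochs, accumulation_steps):
    """Take one optimizer step per `accumulation_steps` mini-batches.

    Hypothetical helper. Each mini-batch is backpropagated right away, so its
    activations and graph are freed before the next batch is processed; only
    the summed gradients in each parameter's .grad are kept between batches.
    """
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        for i, (inputs, labels) in enumerate(dataloader):
            outputs = model(inputs)
            # Dividing makes the summed gradient equal the gradient of the mean
            # loss over the whole "big batch" (assumes reduction="mean" and
            # equal-sized mini-batches).
            loss = loss_fn(outputs, labels) / accumulation_steps
            loss.backward()  # adds to .grad instead of overwriting it
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # Gradients from a trailing, incomplete group of mini-batches are never
        # applied; the zero_grad() at the start of the next epoch clears them.


if __name__ == "__main__":
    # Hypothetical toy setup: random 3D volumes, 20 classes, mini-batches of 5,
    # one step every 4 mini-batches (an effective batch size of 20).
    torch.manual_seed(0)
    volumes = torch.randn(40, 1, 8, 8, 8)
    targets = torch.randint(0, 20, (40,))
    loader = DataLoader(TensorDataset(volumes, targets), batch_size=5)
    model = nn.Sequential(nn.Conv3d(1, 4, kernel_size=3), nn.ReLU(),
                          nn.AdaptiveAvgPool3d(1), nn.Flatten(),
                          nn.Linear(4, 20))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    train_with_accumulation(model, loader, nn.CrossEntropyLoss(), optimizer,
                            num_epochs=2, accumulation_steps=4)
```

With the numbers from the question, accumulation_steps would be 5000 // 5 = 1000. This is only equivalent to a true batch of that size for layers that do not use batch statistics: an nn.BatchNorm3d layer, for example, would still normalize over each mini-batch of 5.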

Thanks for reading.