I have been working on code to train a neural network, and right now I am writing a feature that finds the maximum batch size that fits into GPU memory for a given model and training set.
So here is my code:
```python
import os

import numpy as np
import torch
import torch.optim as optim


def get_free_memory():
    """Sum the free memory (in MB) of all GPUs visible to this process."""
    import GPUtil
    CUDA_VISIBLE_DEVICES = os.environ.get('CUDA_VISIBLE_DEVICES')
    memory = 0
    for GPU in GPUtil.getGPUs():
        if CUDA_VISIBLE_DEVICES is None or str(GPU.id) in CUDA_VISIBLE_DEVICES:
            memory += GPU.memoryFree
    return memory


def no_free_mem(mem_per_sample, available):
    """Stop when five times the worst per-sample cost exceeds the free memory."""
    if not mem_per_sample:  # no measurements yet, keep going
        return False
    return 5 * np.array(mem_per_sample).max() > available


def main():
    model = PytorchModel(config_network, config_inputs, config_outputs, "")
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    max_len = config_network["max_sequence_length"]
    x = get_first_sample(X, max_len, model.inputs_cfg)
    y = get_first_sample(Y, max_len, model.outputs_cfg, output=True)
    optimizer = optim.Adam(filter(lambda param: param.requires_grad, model.parameters()))

    moremem = True
    batch_size = 1
    prev_freemem = get_free_memory()
    mem_per_sample = []
    optimizer.zero_grad()

    while moremem:
        # Forward pass, then re-check free memory after each stage of the step.
        y_pred, _, _ = model(x)
        freemem = get_free_memory()
        if no_free_mem(mem_per_sample, freemem):
            break

        loss, _ = model.loss(y_pred, y)
        freemem = min(freemem, get_free_memory())
        if no_free_mem(mem_per_sample, freemem):
            break

        loss.backward()
        freemem = min(freemem, get_free_memory())
        if no_free_mem(mem_per_sample, freemem):
            break

        optimizer.step()
        freemem = min(freemem, get_free_memory())
        if no_free_mem(mem_per_sample, freemem):
            break

        # Record how much extra memory this batch size cost compared to the last one.
        if prev_freemem - freemem > 0:
            mem_per_sample.append(prev_freemem - freemem)
        if no_free_mem(mem_per_sample, freemem):
            break

        # There is room for one more sample: grow the batch and repeat.
        batch_size += 1
        prev_freemem = min(prev_freemem, freemem)
        x = insert_sample(x)
        y = insert_sample(y)

    print("GUESSING batch_size, ", batch_size)
```
I compute how much GPU memory is available at each step of the forward and backward passes.
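For reference, this is a minimal sketch of how the same reading could be taken through PyTorch itself instead of GPUtil, as a cross-check (it assumes a PyTorch version that ships `torch.cuda.mem_get_info`; `get_free_memory_torch` is just an illustrative name):

```python
import torch

def get_free_memory_torch():
    """Cross-check GPUtil: query free memory straight from the CUDA driver."""
    free_mb = 0
    # torch only enumerates devices allowed by CUDA_VISIBLE_DEVICES,
    # so no manual filtering is needed here.
    for dev in range(torch.cuda.device_count()):
        # mem_get_info wraps cudaMemGetInfo and returns (free, total) in bytes
        free_bytes, _total_bytes = torch.cuda.mem_get_info(dev)
        free_mb += free_bytes / 1024 ** 2  # MB, to match GPUtil's memoryFree
    return free_mb
```

One thing I am unsure about is how PyTorch's caching allocator interacts with these readings, since it can hold on to freed blocks, so driver-level free memory does not necessarily go back up when tensors are released.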
Then I grow the batch size iteratively until memory saturates, and I keep track of how much memory each sample requires so I can predict whether there will be enough memory to insert another sample into the batch.
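To make the stop condition concrete, here is a worked example with made-up numbers, using the `no_free_mem` from the code above (all values in MB, like GPUtil's `memoryFree`):

```python
# Hypothetical per-sample memory deltas recorded so far (MB)
mem_per_sample = [310, 295, 340]
available = 1500  # MB currently free

# The 5x safety factor means the loop stops as soon as five times the worst
# observed per-sample cost exceeds the free memory:
# 5 * 340 = 1700 MB > 1500 MB, so this returns True and the loop breaks.
print(no_free_mem(mem_per_sample, available))  # True
```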
In my head, this should work; however, the code behaves unstably: sometimes it works fine, and sometimes it crashes with out-of-memory errors. I have tried increasing the safety margin in the stop condition, which makes it crash less often, but nothing seems to work 100% of the time.
So I guess I must be missing something important here.
A pattern I have observed is that crashes are especially frequent when I do multi-task learning, and they occur in the loss.backward() step most of the time.
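To narrow that down, I am thinking of instrumenting the backward step inside the loop roughly like this (a sketch; it assumes a PyTorch version that provides `torch.cuda.reset_peak_memory_stats` and `torch.cuda.max_memory_allocated`):

```python
import torch

# Reset the peak-allocation counter, run backward, then read back the peak
# to see how large the transient allocated during backward actually is.
torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()
loss.backward()
peak = torch.cuda.max_memory_allocated()
print(f"backward transient: {(peak - before) / 1024**2:.1f} MB")
```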
Could someone help me fix this problem, or suggest another way to estimate memory-optimal batch sizes?
I have found some equations on this forum, but none of them worked for me across all the architectures I tested.
Thank you very much!