Relationship between GPU Memory Usage and Batch Size

I wonder whether GPU memory usage has a roughly linear relationship with the batch size used in training?

I was fine-tuning ResNet152. With a batch size of 8, the total GPU memory used is around 4 GB, and when the batch size is increased to 16, the total GPU memory used is around 6 GB. The model itself takes about 2 GB. It seems to me that the GPU memory consumption of training ResNet152 is approximately 2 GB + 2 GB * batch_size / 8?

Increasing the batch size increases the activation sizes during the forward pass, while the model parameters (and their gradients) still use the same amount of memory, since they do not depend on the batch size. This post explains the memory usage in more detail.
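A quick way to check this empirically is to measure the peak allocated memory for a couple of batch sizes with torch.cuda.max_memory_allocated. Here is a minimal sketch, assuming torchvision's resnet152 and a dummy training step (the input shape, optimizer, and loss are just placeholders):

import torch
import torchvision

def peak_memory_gib(batch_size, device="cuda"):
    # Fresh model and optimizer for every measurement.
    model = torchvision.models.resnet152().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    torch.cuda.reset_peak_memory_stats(device)
    x = torch.randn(batch_size, 3, 224, 224, device=device)
    target = torch.randint(0, 1000, (batch_size,), device=device)

    # One forward/backward/step pass; the activations grow with the batch
    # size, while parameters and gradients stay constant.
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()

    return torch.cuda.max_memory_allocated(device) / 1024**3

for bs in (8, 16, 32):
    print(f"batch size {bs}: ~{peak_memory_gib(bs):.2f} GiB peak")

The numbers won't match nvidia-smi exactly, since the caching allocator reserves more memory than it allocates, but the growth with batch size should be close to linear.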


This is a function that I wrote to calculate the activation size of a network, in order to find out how much you can increase the batch size:

import torch

total_output_elements = 0

def calc_total_activation_size(model, call_the_network_function):
    """Count the total number of elements in all module outputs of one forward pass."""
    global total_output_elements
    total_output_elements = 0

    # Forward hook that accumulates the number of elements of every module output.
    def hook(module, input, output):
        global total_output_elements
        if torch.is_tensor(output):
            total_output_elements += output.numel()

    # Register a global forward hook that fires for every module in the forward pass.
    handle = torch.nn.modules.module.register_module_forward_hook(hook)
    result = call_the_network_function()
    handle.remove()
    return result, total_output_elements
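For example (hypothetical usage, assuming torchvision's resnet152, a 224x224 input, and float32 activations):

import torch
import torchvision

model = torchvision.models.resnet152()
x = torch.randn(8, 3, 224, 224)

# Run one forward pass through the hooks and count the activation elements.
result, num_elements = calc_total_activation_size(model, lambda: model(x))

# Rough size estimate, assuming float32 (4 bytes per element).
print(f"{num_elements} activation elements, ~{num_elements * 4 / 1024**2:.0f} MB")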