A memory usage of ~10GB would be expected for a ResNet50 with the specified input shape.
Note that the input itself, all parameters, and especially the intermediate forward activations will use device memory.
While the former two might be small, the intermediates (which are needed for the gradient computation) will most likely take the majority of the memory.
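As a rough rule of thumb, each float32 value takes 4 bytes, so the memory of a single tensor can be estimated from its number of elements (a quick sketch with an arbitrary shape):

```python
import torch

t = torch.randn(64, 3, 224, 224)  # any float32 tensor
mem_bytes = t.nelement() * t.element_size()  # element_size() returns 4 for float32
print('{:.3f}MB'.format(mem_bytes / 1024**2))
```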
A quick way of checking this would be to register forward hooks and sum the number of elements of all output activations.
Note that this approach is not exact, as you would need to check which operations are performed inplace etc., but it would at least give you a rough estimate:
```python
import torch.nn as nn

nb_acts = 0

def count_output_act(m, input, output):
    # accumulate the number of elements of every module output
    global nb_acts
    nb_acts += output.nelement()

[...]

# register the hook on the layers which create the large activations
for module in resnet50.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
        module.register_forward_hook(count_output_act)
```
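For example, the elided setup could look roughly like this (torchvision's `resnet50` and the input shape are placeholder assumptions, not taken from the original post; the hooks only fire once a forward pass is executed):

```python
import torch
import torchvision.models as models

resnet50 = models.resnet50()
resnet50_input = torch.randn(2, 3, 224, 224)  # placeholder shape (assumption)

# ... register the hooks as shown above, then run a forward pass so they fire
out = resnet50(resnet50_input)
```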
Using this approach, I get the following estimate:
```python
nb_params = sum([p.nelement() for p in resnet50.parameters()])
nb_inputs = resnet50_input.nelement()

# assume float32, i.e. 4 bytes per element
print('input elem: {}, param elem: {}, forward act: {}, mem usage: {}GB'.format(
    nb_inputs, nb_params, nb_acts, (nb_inputs + nb_params + nb_acts) * 4 / 1024**3))
```
```
input elem: 18874368, param elem: 25557032, forward act: 2787211008, mem usage: 10.548689991235733GB
```

which comes close to the reported value of ~10.13GB.
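If a GPU is available, one way to sanity-check such an estimate is to compare it against the actually allocated memory (a rough check only, since the caching allocator and temporary buffers add some overhead):

```python
# assumes the model, input, and hooks from the previous snippets
resnet50 = resnet50.to('cuda')
resnet50_input = resnet50_input.to('cuda')

torch.cuda.reset_peak_memory_stats()
out = resnet50(resnet50_input)
print('allocated: {:.3f}GB'.format(torch.cuda.memory_allocated() / 1024**3))
print('max allocated: {:.3f}GB'.format(torch.cuda.max_memory_allocated() / 1024**3))
```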
If you don’t want to train the model and would like to save memory by not storing the intermediate activations, wrap the forward pass into a `with torch.no_grad()` block, which would then use only ~0.167GB.
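A minimal sketch, reusing the model and input from above:

```python
# no intermediate activations are stored for the backward pass inside no_grad,
# so the memory usage is dominated by the input, the parameters, and the output
with torch.no_grad():
    out = resnet50(resnet50_input)
```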