About memory reuse in PyTorch

The test code is:

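import torch
import torch.nn as nn

seq0 = nn.Sequential(nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1))

inp = torch.randn(1, 3, 224, 224)

hooked_modules = (nn.Conv2d,)  # assumed; not shown in the original snippet

def register_hook(module):
    # Attach a forward hook that prints the ids of each module's input and output
    def hook_func(module, input, output):
        print(type(input), id(input[0]), type(output), id(output))
    if isinstance(module, hooked_modules):
        module.register_forward_hook(hook_func)

seq0.apply(register_hook)  # assumed registration; the original post was cut off here
out = seq0(inp)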


It prints as:
<class 'tuple'> 139634598790880 <class 'torch.Tensor'> 139634598874800
<class 'tuple'> 139634598874800 <class 'torch.Tensor'> 139634598874880
<class 'tuple'> 139634598874880 <class 'torch.Tensor'> 139634598874800
<class 'tuple'> 139634598874800 <class 'torch.Tensor'> 139634598874880

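The ids show that objects are reused during the forward pass: the output of the third Conv2d has the same id as the output of the first, and the fourth has the same id as the second. When PyTorch reuses memory like this, how can it compute gradients, given that the intermediate tensors may have been overwritten?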


Do you mean how PyTorch computes gradients? It keeps all the values it needs for the backward pass, but not necessarily as Python objects, since those are quite expensive to work with.
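One way to check this is to compare Tensor.data_ptr() (the address of a tensor's data) instead of id() (the address of the Python wrapper object), and to watch what autograd stores for backward with torch.autograd.graph.saved_tensors_hooks. This is a sketch that assumes a PyTorch version providing that context manager (1.10 or later); the hook and list names are illustrative, not part of the original post.

```python
import torch
import torch.nn as nn

seq0 = nn.Sequential(*[nn.Conv2d(3, 3, 1) for _ in range(4)])
inp = torch.randn(1, 3, 224, 224)

input_ptrs = []   # data address of each conv's input
output_ptrs = []  # data address of each conv's output
saved_ptrs = []   # data address of every tensor autograd saves for backward

def record_ptrs(module, inputs, output):
    # data_ptr() identifies the underlying buffer, unlike id(),
    # which only identifies the Python wrapper object.
    input_ptrs.append(inputs[0].data_ptr())
    output_ptrs.append(output.data_ptr())

for conv in seq0:
    conv.register_forward_hook(record_ptrs)

def pack_hook(t):
    # Called for each tensor autograd stores in the graph during forward.
    saved_ptrs.append(t.data_ptr())
    return t  # store the tensor unchanged

def unpack_hook(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = seq0(inp)

# Each conv's input is saved by the graph (it is needed for the weight
# gradient), so the intermediates stay alive and keep distinct buffers.
inputs_saved = all(p in saved_ptrs for p in input_ptrs)
distinct_outputs = len(set(output_ptrs))
print(inputs_saved)      # True
print(distinct_outputs)  # 4

out.sum().backward()  # backward reads the saved inputs
```

Because the graph holds references to the saved tensors, their buffers cannot be handed to a later layer while backward still needs them; only the Python objects may be recycled.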

Thank you very much.
I realized this when I saw the following code, from the forward method of a custom autograd Function:

@staticmethod
def forward(ctx, i):
    result = i.exp()
    ctx.save_for_backward(result)  # keep result for the backward pass
    return result

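It saves the result for backward with ctx.save_for_backward, so the data stays available to the backward pass even when the Python object (and therefore its id) is reused.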
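A minimal runnable version of that Function, modeled on the Exp example in the torch.autograd.Function documentation, is sketched below; the backward method and the gradient checks are added for illustration and were not part of the snippet above.

```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # d/dx exp(x) = exp(x), so the output is all backward needs
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the tensor stored during forward
        (result,) = ctx.saved_tensors
        return grad_output * result

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = Exp.apply(x)
y.sum().backward()

# The gradient of sum(exp(x)) with respect to x is exp(x)
grad_matches = torch.allclose(x.grad, x.detach().exp())
print(grad_matches)  # True

# gradcheck compares the analytic backward against finite differences
gradcheck_ok = torch.autograd.gradcheck(Exp.apply, (x,))
print(gradcheck_ok)  # True
```

Saving the output rather than the input works here because the derivative of exp can be expressed through its own result, so no extra tensor has to be kept.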
