The test code is:
import torch
import torch.nn as nn

seq0 = nn.Sequential(nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1))
inp = torch.randn(1, 3, 224, 224)

hooked_modules = (nn.Conv2d,)  # module types to hook (this definition was missing above)

def register_hook(module):
    def hook_func(module, input, output):
        print(type(input), id(input[0]), type(output), id(output))
    if isinstance(module, hooked_modules):
        module.register_forward_hook(hook_func)

seq0.train()
seq0.apply(register_hook)
seq0(inp)  # run the forward pass so the hooks fire
It prints:
<class 'tuple'> 139634598790880 <class 'torch.Tensor'> 139634598874800
<class 'tuple'> 139634598874800 <class 'torch.Tensor'> 139634598874880
<class 'tuple'> 139634598874880 <class 'torch.Tensor'> 139634598874800
<class 'tuple'> 139634598874800 <class 'torch.Tensor'> 139634598874880
Apparently memory is being reused during the forward pass: the id of one layer's output shows up again as a later layer's input or output. If PyTorch reuses memory like this, how can it compute gradients correctly, given that the intermediate tensors needed for backward may have been overwritten?
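One caveat about the measurement itself: id() reports the address of the transient Python wrapper object, which CPython recycles once a temporary tensor is garbage-collected, so equal ids do not necessarily mean the underlying storage was reused. A minimal sketch (assuming the same kind of toy Sequential as above) that prints Tensor.data_ptr(), the address of the actual storage, alongside id():

import torch
import torch.nn as nn

seq0 = nn.Sequential(nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1))
inp = torch.randn(1, 3, 224, 224)

def hook_func(module, input, output):
    # data_ptr() is the address of the tensor's underlying storage;
    # id() is the address of the short-lived Python wrapper object
    print(id(output), output.data_ptr())

for m in seq0:
    m.register_forward_hook(hook_func)

out = seq0(inp)  # keep `out` alive so the autograd graph and its saved tensors stay alive

If the data_ptr() values are all distinct while the id() values repeat, the storage was never overwritten and only the Python object ids were recycled.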