Apparently, it works in a minimal example:
```python
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        # Backward hook: runs when the gradient w.r.t. x is computed
        def my_hook(grad):
            self.my_grad = grad.detach().clone().cpu()
            self.artifacts = 'artifacts'
            return grad
        x.register_hook(my_hook)
        return x

model = MyModule()
x = torch.tensor(1.0).requires_grad_(True)
output = model(x)
output.backward()     # the hook fires here and sets model.artifacts
print(model.artifacts)  # prints: artifacts
```
So I need to understand where exactly the problem happens in my more complex case.
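One way I could narrow it down, sketched under my own assumptions: count how often the hook actually fires. The `hook_calls` counter and the `requires_grad` guard below are hypothetical debugging additions, not part of my real module. If the hooked tensor is not on the path from the loss back to the leaves, the hook never runs and `artifacts` is never set, with no error raised. (`register_hook` itself raises a `RuntimeError` if the tensor does not require grad, hence the guard.)

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hook_calls = 0  # hypothetical debug counter

    def forward(self, x):
        def my_hook(grad):
            self.hook_calls += 1  # record that backward reached x
            self.my_grad = grad.detach().clone().cpu()
            self.artifacts = 'artifacts'
            return grad
        # register_hook fails on tensors that don't require grad
        if x.requires_grad:
            x.register_hook(my_hook)
        else:
            print('warning: input does not require grad, hook not registered')
        return x

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
model = MyModule()
_ = model(a)      # hook registered on a
loss = b * 3      # loss does not depend on a
loss.backward()
print(model.hook_calls)             # 0: a is not in loss's graph
print(hasattr(model, 'artifacts'))  # False
```

Dropping the same counter into the real model should show whether the hook is never called (graph problem) or is called but the attribute ends up on a different object than the one I inspect afterwards.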