Ah, this is probably a bug: the tuple returned from the full backward pre-hook appears to be ignored, so grad_output is never actually replaced. Minimal repro:
import torch
import torch.nn as nn
a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)
def fn(module, grad_output):
    # Per the docs, the returned tuple should replace grad_output
    # for the rest of the backward pass.
    return (grad_output[0] * 0,)
model.register_full_backward_pre_hook(fn, prepend=False)
out = model(a)
out.sum().backward()
print(a.grad)  # should be tensor([0., 0.]), but it's not
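
For what it's worth, register_full_backward_hook seems to be a workaround here: it receives grad_input and its documented contract is that a returned tuple replaces grad_input, i.e. the gradients flowing back to the module's inputs. A minimal sketch (zero_grad_input is just an illustrative name, not anything from the API):

import torch
import torch.nn as nn
a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)
def zero_grad_input(module, grad_input, grad_output):
    # The returned tuple replaces grad_input, zeroing the gradients
    # that propagate to the module's inputs.
    return tuple(
        torch.zeros_like(g) if g is not None else None for g in grad_input
    )
model.register_full_backward_hook(zero_grad_input)
out = model(a)
out.sum().backward()
print(a.grad)  # tensor([0., 0.])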