Understanding gradient calculation with backward_pre_hooks

Ah, this is probably a bug. Here is a minimal reproduction:

import torch
import torch.nn as nn

a = torch.ones(2, requires_grad=True)

model = nn.Linear(2, 2)

# Full backward pre-hook: per the docs, the returned tuple should be used
# in place of grad_output when the module's gradients are computed.
def fn(module, grad_output):
    return (grad_output[0] * 0,)

model.register_full_backward_pre_hook(fn, prepend=False)

out = model(a)
out.sum().backward()
print(a.grad)  # expected all zeros since the hook zeroes grad_output, but it isn't
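
For comparison, a plain Tensor.register_hook on the output does propagate a zeroed gradient back to a, which is what I expected the pre-hook to do (minimal sketch, same setup as above):

import torch
import torch.nn as nn

a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)

out = model(a)
# A tensor hook's return value replaces the gradient flowing into `out`.
out.register_hook(lambda grad: grad * 0)
out.sum().backward()
print(a.grad)  # tensor([0., 0.])

As a possible workaround (a sketch based on my reading of the docs, not a confirmed fix), a regular full backward hook that returns a modified grad_input should also zero the gradient reaching a, since the returned grad_input is used in subsequent computations:

import torch
import torch.nn as nn

a = torch.ones(2, requires_grad=True)
model = nn.Linear(2, 2)

def zero_grad_input(module, grad_input, grad_output):
    # Replace the gradient w.r.t. the module's inputs with zeros.
    return tuple(None if g is None else g * 0 for g in grad_input)

model.register_full_backward_hook(zero_grad_input)

out = model(a)
out.sum().backward()
print(a.grad)  # tensor([0., 0.])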