Hi,
I think the error is related to the input's leaf status, i.e. whether x is a leaf tensor with requires_grad set (I am not entirely sure, but this may help).
You can try enabling gradients for the input x by using:
x = x.clone().detach().requires_grad_(True)
as the first line of the forward method.
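Below is a minimal sketch of where that line would go. The module name MyModel, the layer fc and the attribute input_leaf are hypothetical, added only so x.grad can be inspected after backward. Note that detach() cuts x off from whatever graph produced it, so gradients will not flow back past this point. That is fine if the input is raw data, but not if x comes from earlier layers you still want to train.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        # Turn x into a fresh leaf tensor that tracks gradients.
        # detach() severs x from any upstream graph, so no gradient
        # will propagate to whatever produced the original input.
        x = x.clone().detach().requires_grad_(True)
        self.input_leaf = x  # hypothetical handle so x.grad can be inspected later
        return self.fc(x)

model = MyModel()
inp = torch.randn(3, 4)  # plain input, requires_grad=False by default
out = model(inp)
out.sum().backward()
print(model.input_leaf.is_leaf, model.input_leaf.grad.shape)  # prints: True torch.Size([3, 4])
```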
Best