Hi, I just wanted to ask whether it's possible to compute gradients inside a traced (torch.jit.trace) model.
We want to use TorchScript to manually train models on mobile devices by computing gradients and updating the parameters ourselves. I tried to create a simple model like this:
class ReturnGradientsModel(torch.nn.Module):
    def forward(self, input, w, b):
        result = input * w + b
        L = (10 - result).sum()
        L.backward()
        # one plain gradient-descent step, learning rate 1e-2
        w = w - w.grad * 1e-2
        b = b - b.grad * 1e-2
        return w, b

torch.jit.trace(ReturnGradientsModel(), (torch.rand(2, 2, requires_grad=True),
                                         torch.rand(2, 2, requires_grad=True),
                                         torch.rand(2, 2, requires_grad=True)))
but it just fails with

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

even though I explicitly passed tensors with requires_grad=True.
I found this issue: https://github.com/pytorch/pytorch/issues/15644 . But it's about ONNX export, so I'm not sure it applies to torch.jit.trace as well.
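For reference, here is an eager-mode (untraced) sketch of the update I'm after, using torch.autograd.grad instead of .backward() so I don't have to read .grad attributes; the function and variable names here are just my own. This runs fine outside of tracing:

    import torch

    # Same toy model as above: result = input * w + b, loss = (10 - result).sum().
    # torch.autograd.grad returns the gradients directly instead of storing
    # them in w.grad / b.grad.
    def manual_step(inp, w, b, lr=1e-2):
        result = inp * w + b
        loss = (10 - result).sum()
        grad_w, grad_b = torch.autograd.grad(loss, (w, b))
        # Plain gradient-descent update; detach so the new tensors are fresh
        # leaves that can be fed into the next step.
        new_w = (w - lr * grad_w).detach().requires_grad_()
        new_b = (b - lr * grad_b).detach().requires_grad_()
        return new_w, new_b

    inp = torch.rand(2, 2)
    w = torch.rand(2, 2, requires_grad=True)
    b = torch.rand(2, 2, requires_grad=True)
    new_w, new_b = manual_step(inp, w, b)

So the math itself works in eager mode; the failure only shows up once tracing is involved.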
I’m grateful for any help.