JIT Trace model that computes gradients

Hi, I just wanted to ask if it’s possible to compute gradients inside a traced (torch.jit.trace) model.

We want to use TorchScript to manually train models on mobile devices by computing gradients and updating the parameters ourselves. I tried to create a simple model like this:

import torch

class ReturnGradientsModel(torch.nn.Module):

    def forward(self, input, w, b):
        result = input * w + b
        L = (10 - result).sum()
        L.backward()

        # one plain SGD step with learning rate 1e-2
        w = w - w.grad * 1e-2
        b = b - b.grad * 1e-2

        return w, b

torch.jit.trace(ReturnGradientsModel(), (torch.rand(2, 2, requires_grad=True),
                                         torch.rand(2, 2, requires_grad=True),
                                         torch.rand(2, 2, requires_grad=True)))

but it fails with the error

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

even though I explicitly passed tensors with requires_grad=True.
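For what it's worth, the exact same computation works fine in eager mode (outside of tracing), so the update step itself seems correct. Here is the minimal check I ran (variable names are mine):

```python
import torch

# Same computation as in forward(), but run eagerly without tracing.
inp = torch.rand(2, 2)
w = torch.rand(2, 2, requires_grad=True)
b = torch.rand(2, 2, requires_grad=True)

result = inp * w + b
L = (10 - result).sum()
L.backward()

# dL/dw = -inp and dL/db = -1, so .grad is populated on the leaf tensors.
with torch.no_grad():
    w_new = w - w.grad * 1e-2
    b_new = b - b.grad * 1e-2
```

Here w.grad and b.grad come out non-None as expected, so the problem only appears once tracing is involved.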

I found this issue: https://github.com/pytorch/pytorch/issues/15644, but it is about ONNX export, so I'm not sure whether it applies to torch.jit.trace as well.
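Since then I have also been experimenting with torch.jit.script instead of trace, using torch.autograd.grad (which the TorchScript docs list as scriptable) in place of .backward(). Something along these lines seems like it should express the same update step (the function name and the Optional handling are my own; in script mode grad returns optionals, so the asserts are needed to unwrap them):

```python
import torch
from typing import Tuple

@torch.jit.script
def sgd_step(inp: torch.Tensor, w: torch.Tensor,
             b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    result = inp * w + b
    L = (10 - result).sum()
    # torch.autograd.grad is scriptable, unlike Tensor.backward()
    grads = torch.autograd.grad([L], [w, b])
    gw, gb = grads[0], grads[1]
    assert gw is not None and gb is not None  # unwrap Optional[Tensor]
    # one plain SGD step with learning rate 1e-2
    return w - gw * 1e-2, b - gb * 1e-2

new_w, new_b = sgd_step(torch.rand(2, 2),
                        torch.rand(2, 2, requires_grad=True),
                        torch.rand(2, 2, requires_grad=True))
```

I'd still like to know whether trace can support this, though, or whether scripting is the only option here.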

I’m grateful for any help.