Autograd inside TorchScript from Java (libtorch)

Hello, I am trying to use autograd to compute gradients inside a TorchScript model that I call from Java via libtorch, but I am running into the error “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”.

My TorchScript model:

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        # Require a gradient
        x.requires_grad_()
        y = x.pow(2).sum()
        print(f'y.requires_grad: {y.requires_grad}')
        # Check if x requires a gradient
        print(f'x.requires_grad: {x.requires_grad}')
        dy_dx = torch.autograd.grad([y], [x])[0]
        return dy_dx

Saving with:

model_scripted = torch.jit.script(MyModel())
model_scripted.save("model.pt")

This throws no errors and returns the expected values when I reload and run the model in Python with:

x = torch.tensor([1.0, 2.0, 3.0])  # example input
model_loaded = torch.jit.load("model.pt")
dy_dx = model_loaded(x)
print(dy_dx)

However, when I load and run the same model from Java, it fails with:

Traceback of TorchScript, original code (most recent call last):
  File "/var/folders/g9/ykm_xvjx25s8bq5bkfjqy5240000gn/T/ipykernel_39052/954115032.py", line 12, in forward
        # Check if x requires a gradient
        print(f'x.requires_grad: {x.requires_grad}')
        dy_dx = torch.autograd.grad([y], [x])[0]
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return dy_dx
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I’m not sure why this fails when the TorchScript model is loaded into Java but works when it is loaded into Python.

I guess you are trying to deploy the model onto a mobile platform?
If so, I think gradient calculation is disabled by default there, since the main use case is inference.
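A plausible explanation, not confirmed in this thread, is that the Java side runs forward with gradient recording turned off, the way torch.no_grad() does in Python. The minimal sketch below tests that hypothesis without Java: it scripts a copy of the model (prints removed), saves and reloads it, then calls it once normally and once inside torch.no_grad(). The helpers save_and_reload and grad_error_message are hypothetical names used only for this illustration.

```python
import os
import tempfile
from typing import Optional

import torch


class MyModel(torch.nn.Module):
    # Same computation as the model above, minus the print statements.
    def forward(self, x):
        x.requires_grad_()
        y = x.pow(2).sum()
        dy_dx = torch.autograd.grad([y], [x])[0]
        return dy_dx


def save_and_reload(model: torch.nn.Module, path: str):
    """Hypothetical helper: script, save and reload, as in the question."""
    torch.jit.script(model).save(path)
    return torch.jit.load(path)


def grad_error_message(model, x: torch.Tensor) -> Optional[str]:
    """Hypothetical helper: call model(x) with grad mode disabled, which is
    ASSUMED (not confirmed) to mimic what the Java side does. Returns the
    error text, or None if the call succeeds."""
    with torch.no_grad():
        try:
            model(x)
        except RuntimeError as err:
            return str(err)
    return None


path = os.path.join(tempfile.mkdtemp(), "model.pt")
loaded = save_and_reload(MyModel(), path)

# Grad mode enabled (the Python default): d/dx sum(x**2) = 2 * x, i.e. [2., 4., 6.]
print(loaded(torch.tensor([1.0, 2.0, 3.0])))

# Grad mode disabled: y no longer requires grad, so autograd.grad raises.
msg = grad_error_message(loaded, torch.tensor([1.0, 2.0, 3.0]))
print(msg is not None and "does not require grad" in msg)
```

If the no_grad call reproduces the Java error, that would point at whatever guard the Java binding places around forward rather than at the saved model itself.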

Hey @ptrblck, thanks for the response. No, I’m not deploying to a mobile platform; I’m just trying to run the TorchScript model inside a Java application.

Ah OK, I’m not familiar with this workflow and don’t know what the Java environment might set, or whether it would disable gradient calculation.

Do you see the same error if you load the model with the JavaCPP Presets for PyTorch?

When I try the JavaCPP Presets, I run into an issue with jnitorch:

Warning: Version of org.bytedeco:pytorch could not be found.
Warning: Version of org.bytedeco:openblas could not be found.
 Uncaught exception: exiting with status code 1
java.lang.UnsatisfiedLinkError: no jnitorch in java.library.path

If you’re trying this on Android, we’d first need to create builds for that platform.