JIT does not support parameter.requires_grad?

Hi,
Is it a known limitation that jit.trace ignores a temporary requires_grad = False?

Here is an example:

# EXAMPLE 1
import torch
from torch import nn, jit
from torch.optim import SGD


inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")

optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())

        param.requires_grad = True
        x = self.model(x).mean()
        param.requires_grad = False
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)

result.mean().backward()

optimizer.step()
optimizer.zero_grad()

print("It works fine!")  # never reached: backward() raises
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

But when I swap the two requires_grad assignments:

# EXAMPLE 2
import torch
from torch import nn, jit
from torch.optim import SGD


inputs = torch.tensor([2.0], device="cuda")
model = nn.Linear(1, 1, bias=False).to("cuda")

optimizer = SGD(model.parameters(), lr=1e-1)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())

        param.requires_grad = False # True --> False
        x = self.model(x).mean()
        param.requires_grad = True # False --> True
        return x


c = MyModule()
forward = jit.trace(c, (inputs,))
result = forward(inputs)

result.mean().backward()

optimizer.step()
optimizer.zero_grad()

print("It works fine!")
It works fine!

However, without jit it behaves as expected, which is the other way around: example 1 runs fine and example 2 fails with an error!

Additional short question: is there something like torch.jit.ignore for tracing?

For anyone wondering: setting requires_grad inside a traced function is not supposed to work, because jit.trace only tracks tensor operations, not attributes. See [JIT] jit.trace does not support parameter.requires_grad? · Issue #53515 · pytorch/pytorch · GitHub
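A minimal CPU-only sketch of Example 1 (the original runs on CUDA; otherwise the code is unchanged) makes the mechanism visible: both assignments run once, in Python, while the tracer executes forward(), so the shared parameter is left with whatever flag the last assignment set, and the traced graph itself never touches the flag.

```python
import torch
from torch import nn, jit

# CPU-only variant of Example 1.
model = nn.Linear(1, 1, bias=False)


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = model

    def forward(self, x):
        param = next(self.parameters())
        param.requires_grad = True   # Python side effect, not a recorded tensor op
        x = self.model(x).mean()
        param.requires_grad = False  # also runs only while tracing
        return x


inputs = torch.tensor([2.0])
forward = jit.trace(MyModule(), (inputs,))

# The graph only contains the linear + mean ops; the flag is whatever tracing left behind.
print(model.weight.requires_grad)  # False
result = forward(inputs)
print(result.requires_grad)        # False, hence the RuntimeError on backward()
```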

Anyone know a solution? This seems to be supported in TensorFlow, but not in PyTorch. It makes optimizing and training my model via the JIT very hard, even though the model would profit from it.


Script or trace deeper modules or functions; compiling (tracing) an “infrastructure” module like MyModule makes no sense. Use @jit.ignore to break recursive compilation.
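A sketch of that advice applied to Example 1, assuming it is acceptable to trace only the inner nn.Linear: the tensor math is traced, while the wrapper stays plain Python, so its requires_grad_ calls run on every call instead of once at trace time. The names `traced_inner` and the constructor argument `inner` are illustrative, not from the thread.

```python
import torch
from torch import nn, jit

model = nn.Linear(1, 1, bias=False)
inputs = torch.tensor([2.0])

# Trace only the numeric module; the traced module shares its parameters with `model`.
traced_inner = jit.trace(model, (inputs,))


class MyModule(nn.Module):
    """Plain Python wrapper: its flag toggles run eagerly on every call."""

    def __init__(self, inner):
        super().__init__()
        self.model = inner

    def forward(self, x):
        param = next(self.parameters())
        param.requires_grad_(True)
        out = self.model(x).mean()
        param.requires_grad_(False)
        return out


model.weight.requires_grad_(False)  # start from a "frozen" parameter
c = MyModule(traced_inner)
result = c(inputs)
print(result.requires_grad)         # True: the toggle ran before the traced call
print(model.weight.requires_grad)   # False again after the call
```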

You can combine scripting and tracing if you trace functions, e.g.:

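import torch
import torch.distributions as D
from torch import jit
def _beta_rsample_3d(a, b):
    return D.Beta(a, b).rsample()
beta_rsample_3d = jit.trace(_beta_rsample_3d, (torch.ones(1, 1, 1), torch.ones(1, 1, 1)), check_trace=False)  # sampling is random, so skip the trace check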
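The mixing also works the other way around: a traced function can be called from a scripted one, which keeps data-dependent control flow in TorchScript. A small sketch with hypothetical names (`_scale`, `scale`, `pipeline`):

```python
import torch
from torch import jit


def _scale(x):
    return x * 2 + 1


# Trace the straight-line tensor math...
scale = jit.trace(_scale, (torch.ones(3),))


# ...and keep the data-dependent branch in a scripted function that calls it.
@jit.script
def pipeline(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return scale(x)
    return -x


print(pipeline(torch.ones(3)))   # tensor([3., 3., 3.])
print(pipeline(-torch.ones(3)))  # tensor([1., 1., 1.])
```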

Thank you for answering! 🙂

Can you explain why it does not make sense? Obviously I only use it here to showcase a minimal example, but in a more complex model these kinds of modules will exist. The solution you suggested is something I really want to avoid, since it would clutter my model; I would not want to change my model to support scripting/tracing. If tracing included tensor attributes, I could just pass my whole forward pass to jit.trace, which would cleanly separate optimization from model semantics. jit.script would be fine, since it can at least be used with decorators, but unfortunately the missing support for torch.distributions rules out jit.script.
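One possible middle ground, assuming the distribution code can live in a single method: script the module and mark only the torch.distributions call with @torch.jit.ignore, which leaves it as a regular Python call from the compiled forward(). The module and method names are hypothetical. Note that the PyTorch docs say models with ignored functions cannot be exported (they point to @torch.jit.unused for that case).

```python
import torch
import torch.distributions as D
import torch.nn.functional as F
from torch import nn


class BetaHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    @torch.jit.ignore
    def sample(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Left as a plain Python call: torch.distributions is not scriptable.
        return D.Beta(a, b).rsample()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compiled by TorchScript; softplus keeps the concentrations positive.
        conc = F.softplus(self.linear(x)) + 1e-3
        return self.sample(conc, conc)


scripted = torch.jit.script(BetaHead())
out = scripted(torch.randn(4, 2))
print(out.shape)  # torch.Size([4, 2])
```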

By the way: do you know whether torch.jit.script tracks tensor attributes (e.g. requires_grad), or is that something the JIT is generally not supposed to do because it is only meant for deployment?

JIT speedups mostly come from tensor operations, like fused math operations and some optimizations specific to tensors, ops, or layers. Classes like the one above are closer to non-computational wrappers; as with training-loop code, there is not much to optimize there.

It is read-only (and used as meta-information during compilation). Instead, you can detach() tensors; for parameters, you’d have to set it from outside.
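A short sketch of both suggestions, with hypothetical names: detach() is a tensor op, so the tracer records it in the graph, while a parameter's requires_grad flag is toggled eagerly before calling the traced module (which shares its parameters with the original module).

```python
import torch
from torch import nn, jit


def weighted_sum(w, x):
    # detach() is a tensor op, so the tracer records it in the graph.
    return (w.detach() * x).sum()


w = torch.ones(2, requires_grad=True)
x = torch.ones(2, requires_grad=True)
traced_sum = jit.trace(weighted_sum, (w, x))
traced_sum(w, x).backward()
print(w.grad)  # None: the traced detach() blocks this path
print(x.grad)  # tensor([1., 1.])

# For parameters, toggle the flag eagerly, outside the traced call.
model = nn.Linear(1, 1, bias=False)
traced_model = jit.trace(model, (torch.ones(1),))
model.weight.requires_grad_(False)
frozen_out = traced_model(torch.ones(1))
model.weight.requires_grad_(True)
trainable_out = traced_model(torch.ones(1))
print(frozen_out.requires_grad, trainable_out.requires_grad)  # False True
```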

Dunno, I find jit.trace too dangerous / limiting to use on module trees. As for torch.distributions, I mostly stopped using these wrappers when coding for performance.


Edit: Sorry, I didn’t see you already reported it as an issue and got more input there, too.
Best regards

Thomas


Hey, thank you for updating!

I see what you mean by my example not making sense. I just wanted to showcase the toggling of the requires_grad flag.

I am wondering why .detach() is supported but requires_grad is not. To me, both are orthogonal ways to influence gradient backpropagation. I guess it is just a practical limitation of how tracing works?

Why do you regard jit.trace as dangerous/limiting? I am learning a lot here. By the way, if you have a resource where I can read up on the practical use of the JIT, I would appreciate a pointer!

param.requires_grad_(b) may also work, but frankly this may not be an anticipated use case. Intuitively, you own tensors created in functions and can toggle requires_grad freely, but parameters are owned by a module, so it is not clear whether toggling them from inside works cleanly.

As you said above, “trace only tracks tensor operations”; this limits the Python code you can pass pretty severely. “Dangerous” refers to the hardcoding of non-tensor values as constants, though the tracer does issue warnings for that.

Contrived example that fails (late) with jit.trace:

k = x.size(0)
ls = [x[i] for i in range(k)]
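A runnable version of that contrived example (the function name `sum_rows` and the `stack(...).sum(0)` reduction are additions for illustration): the Python loop is unrolled for the example input's row count, so a larger input is silently truncated. The tracer may emit a TracerWarning here.

```python
import torch
from torch import jit


def sum_rows(x):
    k = x.size(0)
    ls = [x[i] for i in range(k)]  # a Python loop: unrolled at trace time
    return torch.stack(ls).sum(0)


# Traced with 3 rows, so the graph hardcodes exactly three x[i] selections.
traced_sum_rows = jit.trace(sum_rows, (torch.ones(3, 2),))

x = torch.ones(5, 2)
print(sum_rows(x))         # tensor([5., 5.])
print(traced_sum_rows(x))  # tensor([3., 3.]): rows 3 and 4 are silently dropped
# With fewer than 3 rows, the traced graph would fail instead (late), when selecting x[2].
```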

Yes. So while requires_grad_ (always use the method!) could probably be supported, the list size here is inherently invisible to PyTorch in tracing.
The summary I always give is (I think it’s also in ch. 15 of our book):

  • In tracing you can do anything you want, but the JIT won’t (and can’t) try to understand it all.
  • In scripting you can only do what the JIT understands, but the JIT will do it all.
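For contrast, a sketch of the same function under jit.script (assuming a PyTorch version whose TorchScript supports list comprehensions over range): the row count stays a runtime value, so inputs of any length are handled.

```python
import torch
from torch import jit


@jit.script
def sum_rows(x: torch.Tensor) -> torch.Tensor:
    k = x.size(0)                  # stays a runtime value in TorchScript
    ls = [x[i] for i in range(k)]  # the loop is compiled, not unrolled
    return torch.stack(ls).sum(0)


print(sum_rows(torch.ones(3, 2)))  # tensor([3., 3.])
print(sum_rows(torch.ones(5, 2)))  # tensor([5., 5.])
```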

This trade-off is inherent, and to me it means that the scope of the second part (scripting) should be extended.

Best regards

Thomas


Thank you again for the explanation @tom @googlebot !

Something like this would solve my problem just fine:

c = MyModule()
# Hypothetical API: tell the tracer to also track some parameters (jit.trace has no `track` argument)
forward = jit.trace(c, (inputs,), track=c.parameters())

To me, tracing looks like the more elegant solution when its requirements are met, which they are in my case, except that I need to record the parameters' requires_grad.

I guess currently there is nothing much I can do except either change my code to support scripting or not use the jit.

Could you tell me why I should use the method? In the PyTorch docs, requires_grad is set directly all the time.

First, let me admit that there is a modicum of taste here.

  1. I sometimes misspell requires.
  2. I find myself thinking of “set requires grad to true” as an operation I apply to a tensor rather than as a change to a data member. In other words, to me, setting the gradient requirement is more “typical use of an object-modifying method” than “typical use of =”.
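Point 1 is easy to demonstrate: tensors accept arbitrary Python attributes, so a misspelled assignment silently does nothing useful, while the same typo in a method call fails loudly. The method form also returns the tensor, so it chains. The misspelling `requries` is deliberate.

```python
import torch

t = torch.zeros(3)
t.requries_grad = True  # typo: silently stores a new, unrelated Python attribute
print(t.requires_grad)  # False

try:
    t.requries_grad_(True)  # the same typo in the method name fails loudly
except AttributeError as err:
    print("AttributeError:", err)

# requires_grad_() returns the tensor itself, so it also chains.
w = torch.zeros(3).requires_grad_()
print(w.requires_grad)  # True
```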

Best regards

Thomas


I just ran into the same issue while using torch.jit.script in PyTorch 1.7.1. The easiest workaround is to call x.requires_grad_(True) to change the value instead of assigning the attribute.
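A sketch of that workaround, with a hypothetical helper name: the scripted function creates its own leaf tensor and flips the flag with the method form, which, per this thread, compiles where the attribute assignment does not.

```python
import torch


@torch.jit.script
def make_trainable(x: torch.Tensor) -> torch.Tensor:
    y = x.detach().clone()  # a fresh leaf tensor owned by this function
    y.requires_grad_(True)  # method form instead of assigning y.requires_grad
    return y


y = make_trainable(torch.ones(3))
(y * 2).sum().backward()
print(y.requires_grad, y.grad)  # True tensor([2., 2., 2.])
```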