Torch.fx and no_grad()

I am trying to ramp up on torch.fx. I have a forward function that uses torch.no_grad() (example below).

def forward(self, x):
    with torch.no_grad():
        tmp_val = … do something …

    out = torch.mul(x, tmp_val)
    return out

I ran a symbolic trace and then printed out the generated code.

traced = torch.fx.symbolic_trace(model)
print(traced.code)

But I can no longer see no_grad() anywhere in the traced code. Is there a way to preserve it in the traced model?
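For reference, here is a minimal, self-contained sketch of what I am seeing (the x + 1.0 line is just a stand-in for the elided computation):

import torch
import torch.fx

class Model(torch.nn.Module):
    def forward(self, x):
        with torch.no_grad():
            tmp_val = x + 1.0  # placeholder for the real computation
        out = torch.mul(x, tmp_val)
        return out

traced = torch.fx.symbolic_trace(Model())
print(traced.code)  # the generated forward contains the add and mul calls, but no no_grad block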

Hi @kyong,

FX does not support context managers (such as no_grad) in traced regions: symbolic tracing only records tensor operations, so entering or exiting a context manager leaves no trace in the generated graph. Can you refactor the original code to something like:

class MyModel(torch.nn.Module):
    def forward(self, x):
        tmp_val = … do something …
        out = torch.mul(x, tmp_val)
        return out

class NoGradWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.mm = MyModel()

    def forward(self, x):
        with torch.no_grad():
            return self.mm(x)

Then you can apply FX transforms to NoGradWrapper.mm and still preserve the no-gradient behavior.
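A quick usage sketch (assuming the classes above are defined and the placeholder in MyModel.forward is filled in with a concrete computation):

import torch
import torch.fx

wrapper = NoGradWrapper()
# Trace / transform only the inner module; the wrapper and its no_grad block stay untouched.
wrapper.mm = torch.fx.symbolic_trace(wrapper.mm)

x = torch.randn(4, requires_grad=True)
out = wrapper(x)
print(wrapper.mm.code)    # FX-generated code for the inner module
print(out.requires_grad)  # False, because the call still runs under no_grad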

Hi James,

Thank you for your comments! :slight_smile: I actually ended up not using torch.fx; instead I modified the network directly in the original source code. But it is good to know that torch.fx does not preserve context managers.