kyong
May 11, 2021, 8:06pm
I am trying to get up to speed with torch.fx. I have a forward function that uses torch.no_grad() (example below):
def forward(self, x):
    with torch.no_grad():
        tmp_val = … do something …
        out = torch.mul(x, tmp_val)
    return out
I ran symbolic tracing and then printed the generated code:
traced = torch.fx.symbolic_trace(model)
print(traced.code)
However, no_grad() no longer appears in the traced code. Is there a way to preserve it in the traced model?
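A minimal, runnable reproduction of this behavior, assuming a recent PyTorch install. `NoGradModel` and the `torch.sigmoid` call are hypothetical stand-ins for the elided computation. During symbolic tracing the Python `with` block is executed once, but only operations on FX proxies are recorded into the graph, so the generated code carries no grad-mode switch and the traced module builds an autograd graph where the original did not.

```python
import torch
import torch.fx


class NoGradModel(torch.nn.Module):
    def forward(self, x):
        with torch.no_grad():
            # Hypothetical stand-in for "… do something …"
            tmp_val = torch.sigmoid(x)
            out = torch.mul(x, tmp_val)
        return out


model = NoGradModel()
traced = torch.fx.symbolic_trace(model)
print(traced.code)  # only sigmoid and mul are recorded; no grad-mode call

x = torch.randn(4, requires_grad=True)
eager_out = model(x)
traced_out = traced(x)

# The eager module computes under no_grad, so no autograd history is kept.
print(eager_out.requires_grad)
# The traced module lost the context manager, so autograd tracks the ops.
print(traced_out.requires_grad)
```

The numerical results are identical; only the gradient-tracking behavior differs.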
Hi @kyong,
FX does not support context managers (such as no_grad) in traced regions. Could you refactor the original code to something like:
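import torch

class MyModel(torch.nn.Module):
    def forward(self, x):
        tmp_val = … do something …
        out = torch.mul(x, tmp_val)
        return out

class NoGradWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.mm = MyModel()

    def forward(self, x):
        # no_grad lives here, outside the module that FX will trace
        with torch.no_grad():
            return self.mm(x)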
Then you can apply FX transforms to NoGradWrapper.mm and still preserve the no-gradient behavior, because the no_grad() context sits in the untraced wrapper rather than in the traced region.
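A self-contained sketch of this refactor, assuming a recent PyTorch install. `torch.sigmoid` again stands in for the elided computation, and `mul_to_add` is a hypothetical example transform (not part of FX) that simply rewrites every `torch.mul` call as `torch.add` using the standard node-rewriting pattern. The traced and rewritten `GraphModule` is assigned back to `wrapper.mm`, while the `no_grad()` block stays in the untraced wrapper.

```python
import torch
import torch.fx


class MyModel(torch.nn.Module):
    def forward(self, x):
        # Hypothetical stand-in for "… do something …"
        tmp_val = torch.sigmoid(x)
        out = torch.mul(x, tmp_val)
        return out


class NoGradWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mm = MyModel()

    def forward(self, x):
        # The context manager lives outside the traced submodule,
        # so it survives any FX rewrite of self.mm.
        with torch.no_grad():
            return self.mm(x)


def mul_to_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical example transform: rewrite every torch.mul call as torch.add.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.mul:
            node.target = torch.add
    gm.graph.lint()
    gm.recompile()  # regenerate forward() from the modified graph
    return gm


wrapper = NoGradWrapper()
# Trace only the inner module, transform it, and swap it back in.
wrapper.mm = mul_to_add(torch.fx.symbolic_trace(wrapper.mm))
print(wrapper.mm.code)

x = torch.randn(4, requires_grad=True)
out = wrapper(x)
print(out.requires_grad)  # no_grad is still applied by the wrapper
```

Because FX never sees the wrapper's `forward`, any graph pass on `wrapper.mm` leaves the gradient behavior untouched.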
kyong
May 28, 2021, 4:13am
Hi James,
Thank you for your comments! I ended up not using torch.fx and instead modified the network in the original source code I was working with. Still, it is good to know that torch.fx does not preserve context managers.