I am trying to ramp up on torch.fx. I have a forward function that uses torch.no_grad() (example below).
def forward(self, x):
    with torch.no_grad():
        tmp_val = … do something …
    out = torch.mul(x, tmp_val)
    return out
I ran symbolic trace and printed out the generated code:

traced = torch.fx.symbolic_trace(model)
print(traced.code)

But no_grad() no longer appears anywhere in the traced code. Is there a way to keep it in the traced model?
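As I understand it, symbolic tracing only records tensor operations as graph nodes, so a Python context manager like torch.no_grad() is executed during tracing but never captured in the graph. One common workaround (a sketch, not necessarily the only or official answer) is to factor the no_grad section into a free function and register it with torch.fx.wrap, so the tracer keeps it as a single opaque call and the context manager still runs at execution time. The names compute_tmp_val / MyModule below are hypothetical stand-ins for your real code, and 2.0 * x is a placeholder for the elided "do something" step:

```python
import torch
import torch.fx

# Hypothetical helper standing in for the elided "do something" step.
# torch.fx.wrap makes symbolic_trace record this call as one opaque
# call_function node instead of tracing into its body, so the
# no_grad() context is preserved and re-entered at runtime.
@torch.fx.wrap
def compute_tmp_val(x):
    with torch.no_grad():
        return 2.0 * x  # placeholder computation

class MyModule(torch.nn.Module):
    def forward(self, x):
        tmp_val = compute_tmp_val(x)
        return torch.mul(x, tmp_val)

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)  # compute_tmp_val shows up as a single call
```

A similar effect can be had by moving the no_grad section into its own nn.Module and marking it as a leaf module via a custom Tracer's is_leaf_module, if you prefer module granularity over a free function.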