Hey guys, I would like to trace/build a PyTorch IR for a backward pass (backprop) using Torch's JIT, but for some reason every time I try to trace, it goes through the program 3 times. The first time, it shows that the loss tensor has a grad_fn, but on the passes after that, the loss tensor does not have a grad_fn.
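A minimal sketch to make the repeated passes visible. It assumes the extra passes come from the verification step that `torch.jit.trace` runs by default (`check_trace=True`). The function `record_grad_mode` and the list `grad_modes` are hypothetical names used only for illustration. The sketch logs whether autograd is enabled each time the Python function is called.

```python
import torch

grad_modes = []  # hypothetical log: one entry per call of the traced function

def record_grad_mode(x):
    # Python side effects are not recorded in the trace, but they still run on every call
    grad_modes.append(torch.is_grad_enabled())
    return x * 2

x = torch.rand(3, requires_grad=True)

# Default tracing: the function is traced, then re-run to check the trace
torch.jit.trace(record_grad_mode, x)
checked = list(grad_modes)
print("check_trace=True :", checked)

# With the check disabled, the function should be called only once
grad_modes.clear()
torch.jit.trace(record_grad_mode, x, check_trace=False)
unchecked = list(grad_modes)
print("check_trace=False:", unchecked)
```

If the later entries print as `False`, that matches the behaviour described above: only the first pass runs with gradient recording enabled, so `loss.backward()` can only succeed there. Even with `check_trace=False`, the tracer is designed to record the tensor operations of the forward computation, so this alone is not guaranteed to capture the backward pass in the IR.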
Toy example:
import torch

def simple(x):
    loss = x.sum()
    print(loss)
    loss.backward()
    print("hello")
    return loss

if __name__ == "__main__":
    example = torch.rand(1, 3, 224, 224, requires_grad=True)
    traced_script_loss = torch.jit.trace(simple, example)
error:
in backward allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
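The error itself can be reproduced without any tracing. Assuming the later passes run with gradient recording disabled (which is what the trace checker appears to do), `x.sum()` produces a tensor with no `grad_fn`, and calling `backward()` on it raises the same `RuntimeError`. A minimal sketch (the variable names are illustrative):

```python
import torch

x = torch.rand(1, 3, 224, 224, requires_grad=True)

# Normal mode: the sum is recorded by autograd, so it has a grad_fn
loss = x.sum()
print(loss.grad_fn)

# Gradient recording disabled: the result is not attached to the autograd graph
with torch.no_grad():
    loss_no_grad = x.sum()
print(loss_no_grad.grad_fn)

# Calling backward() on the detached result fails with the error from the trace
try:
    loss_no_grad.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```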
Any guidance would be very appreciated.