Hey guys, I would like to trace/build a PyTorch IR for a backward pass (backprop) using Torch's JIT, but for some reason every time I try to trace, it goes through the program 3 times. The first time, it prints that the loss tensor has a grad_fn. But on the passes after that, it shows that the loss tensor does not have a grad_fn.
```python
import torch

def simple(x):
    loss = x.sum()
    print(loss)
    loss.backward()
    print("hello")
    return loss

if __name__ == "__main__":
    example = torch.rand(1, 3, 224, 224, requires_grad=True)
    traced_script_loss = torch.jit.trace(simple, example)
```
The end of the traceback:

```
... in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
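One possible explanation, stated here as an assumption rather than something confirmed in this thread: `torch.jit.trace` defaults to `check_trace=True`, which re-traces and re-runs the function to verify the trace, and those extra passes may run with gradient recording disabled. The minimal sketch below does not use the tracer at all. It only shows that the same `x.sum()` computed under `torch.no_grad()` has no `grad_fn`, and that calling `backward()` on it raises the same `RuntimeError`. The helper name `grad_fn_demo` is hypothetical.

```python
import torch


def grad_fn_demo():
    """Compare the same loss computed with and without autograd recording."""
    x = torch.rand(1, 3, 224, 224, requires_grad=True)

    # Normal mode: autograd records the sum, so the loss has a grad_fn.
    loss = x.sum()
    has_grad_fn = loss.grad_fn is not None

    # Under torch.no_grad() (assumed to mirror the later trace passes),
    # nothing is recorded, so the result has no grad_fn.
    with torch.no_grad():
        loss_no_grad = x.sum()
    no_grad_fn = loss_no_grad.grad_fn is None

    # Calling backward() on such a tensor raises a RuntimeError.
    try:
        loss_no_grad.backward()
        message = ""
    except RuntimeError as err:
        message = str(err)

    return has_grad_fn, no_grad_fn, message


if __name__ == "__main__":
    print(grad_fn_demo())
```

If that assumption holds, passing `check_trace=False` to `torch.jit.trace` would skip the verification passes. Whether the `backward()` call then shows up in the traced graph is not something this sketch checks.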
Any guidance would be very much appreciated.