Trace Backward Pass

Hey guys, I would like to trace (i.e. build a PyTorch IR for) a backward pass using Torch's JIT, but every time I try to trace, the tracer runs through the program three times. On the first pass, the printout shows that the loss tensor has a grad_fn, but on the later passes it shows that the loss tensor does not have a grad_fn.

Toy example:

import torch

def simple(x):
    loss = x.sum()
    print(loss)
    loss.backward()
    print("hello")
    return loss

if __name__ == "__main__":
    example = torch.rand(1, 3, 224, 224, requires_grad=True)
    traced_script_loss = torch.jit.trace(simple, example)

error:
in backward allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Any guidance would be very appreciated :slight_smile:


I am having the same issue: tracing through backward() gives me the same error, RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. The stack trace points at loss.backward() as the problem, but loss does require grad.

It looks like tracing backward() calls should be supported…
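If the first tracing pass succeeds and only the later verification passes fail (as the original post describes), one thing worth trying is disabling those extra check runs with `trace`'s `check_trace` argument. This is a guess at the failure mode, and note that the side effects of `backward()` (writing into `.grad`) will not be captured in the trace either way:

```python
import torch

def simple(x):
    loss = x.sum()
    loss.backward()
    return loss

example = torch.rand(1, 3, 224, 224, requires_grad=True)
# check_trace=False skips the re-runs jit.trace uses to verify the trace,
# which appear to be the passes where the input no longer carries grad
# history (an assumption about what is going wrong here)
traced = torch.jit.trace(simple, example, check_trace=False)
```

Even if this gets past the error, the resulting graph only contains the forward computation, so it is more of a diagnostic step than a real fix.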