Fixed it. Key takeaways for fixing autograd problems like these:
- Use torch.autograd.set_detect_anomaly(True) to isolate the line causing the problem.
- Call .clone() on the variables on the right-hand side of the update.
Had to replace the line:
dx = 2.0 * torch.stack([torch.rfft(x_r[0], 2), torch.rfft(x_r[1], 2)], dim=0)
with:
dx = 2.0 * torch.stack([torch.rfft(x_r[0].clone(), 2), torch.rfft(x_r[1].clone(), 2)], dim=0)
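For reference, a minimal sketch of the pattern (the tensor names, the squaring op, and the update loop are illustrative, not the original code): anomaly detection makes the backward-pass error point at the forward-pass line that caused it, and cloning the tensors read on the right-hand side keeps an unmodified copy for autograd even after the later in-place update.

import torch

# Report backward-pass errors at the forward-pass line that caused them.
torch.autograd.set_detect_anomaly(True)

x = torch.randn(4, requires_grad=True)
x_r = x + 0.0  # intermediate tensor that gets updated in place below

for _ in range(3):
    # The backward of ** 2 saves its input, so clone x_r here; otherwise the
    # in-place update below modifies the saved tensor and backward() raises
    # "one of the variables needed for gradient computation has been modified
    # by an inplace operation".
    dx = 2.0 * (x_r.clone() ** 2)
    x_r -= 0.1 * dx  # in-place update of the same variable

x_r.sum().backward()
print(x.grad)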
Found the following threads / blogs useful: