You should mark p.grad as static. To check that this is safe, print p.grad.data_ptr() on each iteration and confirm that every gradient keeps the same address. In general the gradient tensors do stay at the same address across iterations, so this is usually safe.
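A minimal sketch of the data_ptr() check is below. It uses a toy nn.Linear model, an SGD optimizer and a helper named grad_ptrs, all of which are hypothetical and stand in for your own training loop. It also assumes you call zero_grad(set_to_none=False), so the .grad tensors stay allocated and backward() accumulates into them in place. With the default set_to_none=True, the grads are freed and reallocated each step, and their addresses are not guaranteed to match.

```python
import torch

# Hypothetical toy model and optimizer standing in for a real training loop.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)


def grad_ptrs(module):
    """Return the storage address of every existing .grad tensor."""
    return [p.grad.data_ptr() for p in module.parameters() if p.grad is not None]


history = []
for step in range(3):
    x = torch.randn(8, 4)
    loss = model(x).sum()
    loss.backward()
    ptrs = grad_ptrs(model)
    print(f"step {step}: {ptrs}")
    history.append(ptrs)
    opt.step()
    # Zero the grads in place instead of setting them to None, so the same
    # storage is reused (and accumulated into) on the next backward().
    opt.zero_grad(set_to_none=False)

stable = all(ptrs == history[0] for ptrs in history)
print("grad addresses stable across iterations:", stable)
```

If any printed address changes between steps, the gradients are being reallocated, and marking them static would not be safe for that loop.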