I got it fixed by adding .cpu() at the end of each call:
ave_grads.append(p.grad.abs().mean().cpu())
max_grads.append(p.grad.abs().max().cpu())
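
For context, here is a minimal sketch of how those two lines might sit inside a gradient-collection loop. It assumes the original problem was that the gradient statistics lived on the GPU and a later step (for example a plotting call that converts to NumPy) needed host-side values. The helper name `collect_grad_stats` and the toy model are hypothetical, made up for illustration. The sketch also calls `.item()` after `.cpu()`, which is an addition of mine that turns each 0-dim tensor into a plain Python float.

```python
import torch
import torch.nn as nn


def collect_grad_stats(named_parameters):
    """Return (names, mean |grad|, max |grad|) as host-side Python floats.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    layers, ave_grads, max_grads = [], [], []
    for name, p in named_parameters:
        # Skip frozen parameters and parameters that received no gradient.
        if not p.requires_grad or p.grad is None:
            continue
        layers.append(name)
        # .cpu() copies the 0-dim result to host memory (a no-op if it is
        # already there); .item() converts it to a plain Python float.
        ave_grads.append(p.grad.abs().mean().cpu().item())
        max_grads.append(p.grad.abs().max().cpu().item())
    return layers, ave_grads, max_grads


# Toy model (an assumption for the example); runs on GPU if one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)).to(device)

loss = model(torch.randn(16, 4, device=device)).pow(2).mean()
loss.backward()

layers, ave_grads, max_grads = collect_grad_stats(model.named_parameters())
```

Because the lists end up holding ordinary floats, they can be handed to code that expects NumPy-compatible values regardless of which device the model was trained on.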