from torch.autograd.gradcheck import zero_gradients


This error appeared after upgrading PyTorch from 1.8 to 1.9.0.

When using this line:

    from torch.autograd.gradcheck import zero_gradients

I get this error message:

    ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck'

My code calls:

    zero_gradients(im)

What is the new command equivalent in PyTorch 1.9.0?

Unfortunately, it was removed as part of a refactor of that code. I'm not sure there's an equivalent you can just import, but since it's relatively simple, you might want to just replicate it in your own code.
This was the original function:

import collections.abc

import torch


def zero_gradients(x):
    # Zero the .grad of a single tensor, or recurse into an iterable of tensors.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)
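Once you copy that helper into your own module, existing call sites like `zero_gradients(im)` work unchanged. A minimal sketch of using it (assuming PyTorch is installed; the tensor name `im` mirrors the question):

```python
import collections.abc

import torch


def zero_gradients(x):
    # Replicated helper: zero the .grad of a tensor, or recurse into an iterable.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)


im = torch.randn(3, requires_grad=True)
im.sum().backward()   # populates im.grad with ones
zero_gradients(im)    # resets the gradient in place
print(im.grad)        # a tensor of zeros
```

Note that for a single tensor you can also simply set `im.grad = None`; autograd will reallocate the gradient on the next backward pass, and this is what `zero_grad(set_to_none=True)` does elsewhere in PyTorch. The replicated helper is only needed if you want the old zero-in-place behavior for arbitrary nests of tensors.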