Gradcheck for multiple samples

TL;DR Is there any way to get gradcheck to work for multiple samples?

Hi All,

I’ve been using torch.autograd.gradcheck to validate a custom torch.autograd.Function I’ve written. However, I’d like to test the accuracy of its backward method on multiple samples via finite differences.
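For concreteness, here is one way such a custom Function might look. This is a minimal sketch assuming the function is x**3 with a hand-written backward (the class name Cube is hypothetical), checked with gradcheck on a single sample:

```python
import torch

class Cube(torch.autograd.Function):
    """Hypothetical example: y = x**3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx x**3 = 3 * x**2, chained with the incoming gradient
        return grad_output * 3 * x ** 2

# gradcheck expects double-precision inputs that require grad
x = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(Cube.apply, (x,))
print(ok)  # True
```

gradcheck compares the analytical gradient from backward against a finite-difference estimate and returns True when they agree within its tolerances.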

torch.autograd.gradcheck compares the analytical gradient from the backward formula against a finite-difference estimate for a single sample, but it doesn’t work for multiple samples: given a batch, it computes the Jacobian across all samples, which is something I don’t want.
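To illustrate what "the Jacobian across samples" means, the sketch below (using a plain x**3 as a stand-in for the custom function) computes the full Jacobian of a batched input with torch.autograd.functional.jacobian. For an input of shape (N, D) it has shape (N, D, N, D), and the blocks that couple different samples are zero for a per-sample function. gradcheck, in its default mode, still builds and compares the full Jacobian, cross-sample blocks included.

```python
import torch

def cube(x):
    # Stand-in for the custom function: acts on each sample independently
    return x ** 3

# 3 samples with 4 features each
x = torch.randn(3, 4, dtype=torch.double)
jac = torch.autograd.functional.jacobian(cube, x)
print(jac.shape)  # torch.Size([3, 4, 3, 4])

# jac[i, :, j, :] with i != j links sample i's output to sample j's input.
# For a per-sample function these entries are all zero.
off_diag = jac[0, :, 1, :]
print(torch.count_nonzero(off_diag).item())  # 0
```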

I did try using torch.func.vmap. However, if one sample fails, it raises an error for that sample, which defeats the purpose of checking multiple samples at once.

Is there any workaround? I did check against a PyTorch equivalent whose backward is computed automatically by autograd, comparing the gradients element-wise with torch.allclose, but that only compares two analytical gradients rather than checking against a finite-difference estimate. Is there any solution to this?

Would using raise_exception=False work? I would assume it wouldn’t break your vmap code.
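A small sketch of what raise_exception changes, using a hypothetical WrongCube Function whose backward is deliberately wrong (2*x**2 instead of 3*x**2). With the default raise_exception=True a mismatch raises an error (caught here as RuntimeError, which gradcheck's error type derives from); with raise_exception=False the same mismatch is simply reported as False:

```python
import torch

class WrongCube(torch.autograd.Function):
    """Hypothetical x**3 with a deliberately wrong backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x ** 2  # wrong on purpose: should be 3 * x**2

# Values in [0.5, 1.5) so the wrong and correct gradients clearly differ
x = (torch.rand(3, dtype=torch.double) + 0.5).requires_grad_()

# Default (raise_exception=True): a mismatch raises
try:
    torch.autograd.gradcheck(WrongCube.apply, (x,))
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# raise_exception=False: the same mismatch comes back as a plain bool
result = torch.autograd.gradcheck(WrongCube.apply, (x,), raise_exception=False)
print(result)  # False
```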

So, I’ve just tried this and I got an error.

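import torch

def custom_func(x):
  return x**3  # but with a custom backward

x = torch.randn(10, 10, dtype=torch.double, requires_grad=True)  # example batch

torch.func.vmap(torch.autograd.gradcheck(func=custom_func, inputs=(x), raise_exception=False), in_dims=(0))(x)  # fails with the TypeError below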

Stacktrace:

Traceback (most recent call last):
  File "~/gradcheck_vmap.py", line 126, in <module>
    result = torch.func.vmap(torch.autograd.gradcheck(func=custom_func, inputs=(x), raise_exception=False), in_dims=(0))(x)
  File "~/myenv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "~/myenv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "~/myenv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
TypeError: 'bool' object is not callable

EDIT: The error is because I call gradcheck directly with (x), so vmap receives its return value instead of a function.

You are already executing gradcheck while building the vmap call, so it directly returns a bool value, while vmap expects a function.
Besides that, the vmapped function must return tensors, so more changes are needed.
Also, I’m using an internal function (torch._C._functorch.get_unwrapped) to unwrap the tensor inside vmap, which I assume could easily break, but this code seems to work for now:

import torch
from functools import partial

def custom_func(x):
    return x**3  # but with a custom backward

def my_gradcheck(x, func):
    # Internal API: unwrap the BatchedTensor that vmap passes in (may break in future releases)
    vx = torch._C._functorch.get_unwrapped(x)
    # raise_exception=False makes gradcheck return a bool instead of raising
    ret = torch.autograd.gradcheck(func=func, inputs=(vx,), raise_exception=False)
    # vmap requires tensor outputs, so wrap the bool in a tensor
    return torch.tensor(ret)

x = torch.randn(10, 10, dtype=torch.double, requires_grad=True)

partial_gradcheck = partial(my_gradcheck, func=custom_func)
vmap_gradcheck = torch.func.vmap(partial_gradcheck)
vmap_gradcheck(x)
# tensor([True, True, True, True, True, True, True, True, True, True])
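If relying on the internal get_unwrapped is a concern, a plain Python loop is a simple alternative for testing. The sketch below (the helper per_sample_gradcheck and the SignBuggyCube Function are hypothetical) calls gradcheck once per row with raise_exception=False, so each call only sees one sample and a failing sample does not stop the others. The backward is deliberately wrong only for negative inputs, so one sample passes and one fails:

```python
import torch

class SignBuggyCube(torch.autograd.Function):
    """Hypothetical x**3 whose backward is only correct for x >= 0."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Bug on purpose: 3 * x * |x| equals 3 * x**2 only when x >= 0
        return grad_output * 3 * x * x.abs()

def per_sample_gradcheck(func, batch, **gradcheck_kwargs):
    """Run gradcheck on each row of `batch` separately; return a bool tensor."""
    results = []
    for sample in batch:
        # Each row becomes its own leaf tensor, so gradcheck only sees that sample
        inp = sample.detach().clone().requires_grad_(True)
        ok = torch.autograd.gradcheck(func, (inp,), raise_exception=False, **gradcheck_kwargs)
        results.append(ok)
    return torch.tensor(results)

x = torch.tensor([[0.5, 1.0, 2.0],     # all positive: backward is correct
                  [-0.5, 1.0, 2.0]],   # contains a negative: backward is wrong
                 dtype=torch.double)
mask = per_sample_gradcheck(SignBuggyCube.apply, x)
print(mask)  # sample 0 passes (True), sample 1 fails (False)
```

The loop gives up vmap's batching, which is usually acceptable in a test, in exchange for using only public APIs.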

That’s a neat trick! Thanks for the example @ptrblck!
