Gradcheck fails for CUDA extension

Hello all,

I’m trying to implement a CUDA extension for nearest-neighbor search. I have verified that the forward result is correct, but when I run gradcheck it raises a Jacobian mismatch error. I read on the forum that changing eps can fix this, so I tried every eps from 1e-1 down to 1e-7, but it always raises the Jacobian mismatch error regardless of the value. I suspect something is wrong with my CUDA code, but I just can't find it :sob:. Please help me figure out whether it is a bug in my code or a precision problem with gradcheck :weary:

The error looks like this:

Traceback (most recent call last):
  File "test_nnd.py", line 21, in <module>
    print(torch.autograd.gradcheck(dist.double(), (data1.double(), data2.double()), eps=1e-6))
  File "/home/Gilgamesh/anaconda3/envs/pytorch1.3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 289, in gradcheck
    'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
  File "/home/Gilgamesh/anaconda3/envs/pytorch1.3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 227, in fail_test
    raise RuntimeError(msg)
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 0.1937,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, -0.3576,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.2347,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0447,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.5029,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.4172,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, -0.7376,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.4731,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.6407,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2049,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.6109,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000, -0.5588,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.4247,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.9388,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.4619,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2682,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.4619,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.0729,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.1267,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.1173],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9835,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4992,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.7004,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.2980,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -0.1118],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2682,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0745,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.3725,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.4619,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -0.3725]], dtype=torch.float64)
analytical:tensor([[ 0.7765,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.9917,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 1.6396,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.8206,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-2.4007,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.8262,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.0693,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.4617],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.8508,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -0.4484],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.0813,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -1.4701]], dtype=torch.float64)

Here is my CUDA code:

#include <stdio.h>
#include <vector>
#include <math.h>
#include "../cuda_utils.cuh" 
//[b,c,n], [b,c,m], float[b,n], int[b,n]
__global__ void NmDistanceKernel(int b, int c, int n, int m,
	const float *__restrict__ xyz1,const float *__restrict__ xyz2,
	float *__restrict__ result,int *__restrict__ result_i){
	
	int batchsize = blockIdx.x;
	int index = threadIdx.x;
	int stride = blockDim.x;
	xyz1 += batchsize*c*n;
	xyz2 += batchsize*c*m;
	result += batchsize*n;
	result_i += batchsize*n;
	for(int i=index;i<n;i+=stride)
	{
		float min_d = -1;
		int min_i;
		for(int j=0;j<m;j++)
		{
			float d = 0;
			for(int k=0;k<c;k++)
			{
				d += (xyz1[i + k*n] - xyz2[j + k*m]) * (xyz1[i + k*n] - xyz2[j + k*m]);
			}
			//float d = sqrt(s_d)
			if(min_d == -1 || d < min_d)
			{
				min_d = d;
				min_i = j;
				//printf("(%d)",j);
			}
		}
		//printf("(%d)",min_i);
		result[i] = min_d;
		result_i[i] = min_i;
		//printf("%d",min_i);
	}
}
//[b,c,n],[b,c,m],float[b,n] int[b,n]
void NmDistance(int b, int c, int n, int m, const float *xyz1_data, const float *xyz2_data,
	float *dist1_data, float *dist2_data,
	int *idx1_data, int *idx2_data){
	NmDistanceKernel<<<b,optimal_num_threads(n)>>>(b,c,n,m,xyz1_data,xyz2_data,dist1_data,idx1_data);
	NmDistanceKernel<<<b,optimal_num_threads(m)>>>(b,c,m,n,xyz2_data,xyz1_data,dist2_data,idx2_data);

	CUDA_CHECK_ERRORS();
	}
//[b,c,n],[b,c,m],float[b,n] int[b,n]
__global__ void NmDistanceGradKernel(int b,int c,int n,int m,
	const float *__restrict__ xyz1,const float *__restrict__ xyz2,
	const float *__restrict__ grad_dist1,const int *__restrict__ idx1,
	float *__restrict__ grad_xyz1,float *__restrict__ grad_xyz2){
	int batchsize = blockIdx.x;
	int index = threadIdx.x;
	int stride = blockDim.x;
	xyz1 += batchsize*c*n;
	xyz2 += batchsize*c*m;
	grad_xyz1 += batchsize*c*n;
	grad_xyz2 += batchsize*c*m;
	grad_dist1 += batchsize*n;
	idx1 += batchsize*n;
	for (int i=0;i<n;i+=stride){
		float g = grad_dist1[i]*2;
		int id = idx1[i];
			
		for (int k=0;k<c;k++)
		{
			atomicAdd(grad_xyz1 + i + k*n, g*(xyz1[i + k*n]-xyz2[id + k*m]));
			atomicAdd(grad_xyz2 + id + k*m, -(g*(xyz1[i + k*n]-xyz2[id + k*m])));
		}
		
	}
}

//[b,c,n],[b,c,m],float[b,n] int[b,n]
void NmDistanceGrad(int b,int c,int n,int m,
    const float *xyz1_data,const float *xyz2_data,
    float *gradxyz1_data,float *gradxyz2_data,
    const float *graddist1_data,const float *graddist2_data,
    const int *idx1_data,const int *idx2_data){

	NmDistanceGradKernel<<<b,optimal_num_threads(n)>>>(b,c,n,m,xyz1_data,xyz2_data,graddist1_data,idx1_data,gradxyz1_data,gradxyz2_data);
	NmDistanceGradKernel<<<b,optimal_num_threads(m)>>>(b,c,m,n,xyz2_data,xyz1_data,graddist2_data,idx2_data,gradxyz2_data,gradxyz1_data);
	
	CUDA_CHECK_ERRORS();
	
	}

Hi,

The default values for eps and the tolerances are chosen for double-precision Tensors, so make sure everything actually runs in double precision.

What is the formula that you use for the backward pass here?
Also, make sure that you’re at a point where the function is continuously differentiable; otherwise, gradcheck won’t work.

Hi,

Thank you for your reply! Here is my test code.

import torch
from nnd import NNDModule
data1 = torch.rand((2, 3, 5), requires_grad=True).cuda() #[b, c, n]
data2 = torch.rand((2, 3, 4), requires_grad=True).cuda() #[b, c, m]
dist = NNDModule()
print(torch.autograd.gradcheck(dist.double(), (data1.double(), data2.double()), eps=1e-6))

My function can only receive float precision input, so I use .double() to make it accept double-precision tensors. I don't know if that is the right way, but at least the code doesn't raise an input error. Then I feed double-precision random tensors to gradcheck, and the Jacobian error is raised :weary:.
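The formula for the backward pass, where \(y_{x_i}\) is the nearest neighbor of \(x_i\) in the second point set:

\[
\begin{aligned}
dx_i &= 2 \cdot \mathrm{gradinput} \cdot (x_i - y_{x_i}) \\
dy_{x_i} &= -2 \cdot \mathrm{gradinput} \cdot (x_i - y_{x_i})
\end{aligned}
\]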
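For reference, the factor 2 and the opposite signs follow from differentiating the squared distance that the forward kernel actually returns (the sqrt is commented out in NmDistanceKernel), with the neighbor index held fixed. The notation here is introduced only for this derivation: \(x_{k,i}\) is channel k of point i in xyz1, and j is the index of \(y_{x_i}\) in xyz2, so \(y_{k,j}\) is its channel k. Holding j fixed is only valid while no small perturbation changes which point is nearest.

\[
d_i = \sum_{k=1}^{c} \left(x_{k,i} - y_{k,j}\right)^2, \qquad j = \arg\min_{j'} \sum_{k=1}^{c} \left(x_{k,i} - y_{k,j'}\right)^2
\]

\[
\frac{\partial d_i}{\partial x_{k,i}} = 2\left(x_{k,i} - y_{k,j}\right), \qquad
\frac{\partial d_i}{\partial y_{k,j}} = -2\left(x_{k,i} - y_{k,j}\right)
\]

Multiplying each by the incoming gradient of \(d_i\) (gradinput, i.e. grad_dist1[i] in the kernel) gives \(dx_i\) and \(dy_{x_i}\) channel by channel. Because several points of xyz1 can share the same nearest neighbor, their contributions to \(y_j\) must be summed, which is what the atomicAdd does.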

The corresponding code:

float g = grad_dist1[i]*2;
int id = idx1[i]; // index of the nearest neighbor of xyz1[i] in xyz2

for (int k=0;k<c;k++) // traverse the channels
{
	// n is the number of points in xyz1
	// m is the number of points in xyz2
	atomicAdd(grad_xyz1 + i + k*n, g*(xyz1[i + k*n]-xyz2[id + k*m]));
	atomicAdd(grad_xyz2 + id + k*m, -(g*(xyz1[i + k*n]-xyz2[id + k*m])));
}
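A pure-Python reference of the same forward and backward can help separate a formula bug from a kernel bug. The sketch below is an illustration, not the extension's code: it handles a single batch, stores points channel-major as nested lists ([c][n], mirroring the xyz1[i + k*n] layout), and the helper names nn_dist_forward, nn_dist_backward and check_grad_xyz1 are hypothetical. It only checks the gradient of sum(dist) with respect to xyz1, i.e. one vector-Jacobian product rather than the full Jacobian that gradcheck builds.

```python
import random

# Hypothetical CPU reference for one batch; points are stored channel-major
# as nested lists: xyz1[k][i] is channel k of point i (like xyz1[i + k*n]).


def nn_dist_forward(xyz1, xyz2):
    """Squared distance from each point of xyz1 to its nearest point in xyz2,
    plus that neighbor's index (mirrors NmDistanceKernel)."""
    c, n, m = len(xyz1), len(xyz1[0]), len(xyz2[0])
    dist, idx = [], []
    for i in range(n):
        best_d, best_j = None, -1
        for j in range(m):
            d = sum((xyz1[k][i] - xyz2[k][j]) ** 2 for k in range(c))
            if best_d is None or d < best_d:
                best_d, best_j = d, j
        dist.append(best_d)
        idx.append(best_j)
    return dist, idx


def nn_dist_backward(xyz1, xyz2, grad_dist, idx):
    """Gradients w.r.t. xyz1 and xyz2 given grad_dist (mirrors NmDistanceGradKernel,
    with the neighbor index frozen from the forward pass)."""
    c, n, m = len(xyz1), len(xyz1[0]), len(xyz2[0])
    g1 = [[0.0] * n for _ in range(c)]
    g2 = [[0.0] * m for _ in range(c)]
    for i in range(n):  # every point of xyz1 handled exactly once
        g = 2.0 * grad_dist[i]
        j = idx[i]
        for k in range(c):
            step = g * (xyz1[k][i] - xyz2[k][j])
            g1[k][i] += step  # role of atomicAdd(grad_xyz1 + i + k*n, ...)
            g2[k][j] -= step  # role of atomicAdd(grad_xyz2 + id + k*m, -...)
    return g1, g2


def check_grad_xyz1(xyz1, xyz2, eps=1e-6):
    """Largest gap between the central-difference and the analytical gradient
    of sum(dist) with respect to each entry of xyz1."""
    dist, idx = nn_dist_forward(xyz1, xyz2)
    g1, _ = nn_dist_backward(xyz1, xyz2, [1.0] * len(dist), idx)
    worst = 0.0
    for k in range(len(xyz1)):
        for i in range(len(xyz1[0])):
            orig = xyz1[k][i]
            xyz1[k][i] = orig + eps
            plus = sum(nn_dist_forward(xyz1, xyz2)[0])
            xyz1[k][i] = orig - eps
            minus = sum(nn_dist_forward(xyz1, xyz2)[0])
            xyz1[k][i] = orig  # restore the original value
            worst = max(worst, abs((plus - minus) / (2 * eps) - g1[k][i]))
    return worst


random.seed(0)
xyz1 = [[random.random() for _ in range(5)] for _ in range(3)]  # [c=3][n=5]
xyz2 = [[random.random() for _ in range(4)] for _ in range(3)]  # [c=3][m=4]
print(check_grad_xyz1(xyz1, xyz2))  # tiny in double precision unless a neighbor flips
```

If a reference like this agrees with itself but the CUDA kernel does not agree with it, the problem is in the kernel rather than in the formula.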

Also, I think nearest-neighbor search is not continuously differentiable, because max is not continuously differentiable, but torch.max() also passes gradcheck, so I’m a little confused :dizzy_face:

You can slightly change your code to:

import torch
from nnd import NNDModule
data1 = torch.rand((2, 3, 5), requires_grad=True, dtype=torch.double, device="cuda") #[b, c, n]
data2 = torch.rand((2, 3, 4), requires_grad=True, dtype=torch.double, device="cuda") #[b, c, m]
dist = NNDModule()
dist.double()
print(torch.autograd.gradcheck(dist, (data1, data2)))

You shouldn’t need to change the eps.

but the torch.max() can also pass the gradcheck

It does, as long as you’re far enough from a point where the max value changes. If the finite difference actually crosses a point where a different element becomes the max, the test will fail.
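A toy 1-D example (hypothetical, not taken from the extension) shows what happens when the finite difference straddles a switch of the selected neighbor. With candidates at 0 and 1, the nearest one changes at x = 0.5; the analytical gradient follows only the currently selected branch, as the CUDA backward does.

```python
# Hypothetical 1-D toy: nearest-neighbor squared distance to the candidates ys.
def nearest_sq_dist(x, ys):
    return min((x - y) ** 2 for y in ys)


def analytic_grad(x, ys):
    # derivative of the selected branch only, like the CUDA backward pass
    y = min(ys, key=lambda cand: (x - cand) ** 2)
    return 2 * (x - y)


def central_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)


ys = [0.0, 1.0]  # the nearest candidate switches at x = 0.5


def f(x):
    return nearest_sq_dist(x, ys)


for x in (0.3, 0.5 + 1e-7):
    print(x, central_diff(f, x, 1e-6), analytic_grad(x, ys))
```

Far from 0.5 both values agree (about 0.6). Just past 0.5, the perturbed points x ± eps land on different neighbors, so the numerical estimate is about -0.1 while the analytical one is about -1.0.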

My function can only receive float precision input

You actually pass double tensors above, no? Do you convert them back to float before doing the actual computation?

The formula for the backward pass:

The forward pass actually does an outer product between x and y, right? I am surprised not to see one here.

It raises the same Jacobian error after I changed the code as you suggested :sob:

I think the probability of this happening is very low, so I don't think non-differentiability is the reason the error is raised. :thinking:

Yes, I convert them back to float in the PyTorch wrapper class before doing the actual computation.

xyz1 and xyz2 are two point sets, and each point has c channels. The forward pass finds, for every point of xyz1, its nearest point in xyz2 in terms of Euclidean distance. I think that is different from an outer product :joy:. In the backward pass I calculate the derivative of the (squared) Euclidean distance with respect to xyz1[i] and to its nearest point in xyz2, which was determined during the forward pass.

It raises the same Jacobian error after I changed the code as you suggested :sob:

Yes, the code change was just for clarity :slight_smile:

I think the probability of this happening is very low, so I don't think non-differentiability is the reason the error is raised. :thinking:

In theory, yes. In practice, if you generate hundreds of values between 0 and 1, how likely is it that two of them are less than 0.5e-3 apart (half the eps used, meaning that the max would change and we would cross a point of non-differentiability where the numerical gradient returns garbage)?
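The probability question can be made concrete under a simplifying assumption: treat the n values as i.i.d. uniform on [0, 1] and ask whether any two lie closer than delta. For that idealized case there is a closed form, P(all gaps ≥ delta) = (1 − (n − 1)·delta)^n when (n − 1)·delta ≤ 1, which the sketch below checks against a Monte Carlo estimate. This is only a rough proxy for a nearest-neighbor choice flipping, since what matters there is how close competing distances are, not how close raw values are.

```python
import random


def p_some_pair_closer(n, delta):
    """Closed form for n i.i.d. Uniform(0, 1) values: probability that at least
    two are closer than delta. Uses P(all gaps >= delta) = (1 - (n-1)*delta)^n,
    clamped at 0 once (n-1)*delta exceeds 1."""
    return 1.0 - max(0.0, 1.0 - (n - 1) * delta) ** n


def p_monte_carlo(n, delta, trials=20000, seed=0):
    """Simulation of the same event, as a sanity check on the formula."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        v = sorted(rng.random() for _ in range(n))
        if any(b - a < delta for a, b in zip(v, v[1:])):
            hits += 1
    return hits / trials


delta = 0.5e-3  # the half-eps gap from the reply above
print(p_some_pair_closer(30, delta), p_monte_carlo(30, delta))  # both roughly 0.35
print(p_some_pair_closer(300, delta))  # very close to 1
```

With n = 30 (the number of entries in one [2, 3, 5] input) the chance is already about one in three, and with a few hundred values it is practically certain.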

Yes, I convert them back to float in the PyTorch wrapper class before doing the actual computation.

That might be the issue. In my experience, single-precision floats are hard to get working with gradcheck.
You will need to raise eps (which causes more issues with non-differentiable points, as mentioned above) and the tolerance quite a lot.
Could you add a simple float/double template to your function?

Otherwise, you won’t be able to use gradcheck blindly: you’ll have to play with eps/threshold, see when it passes and when it fails, and try to get a feel for whether it’s a numerical-precision issue or a formula issue.
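To see why single precision needs a larger eps, the sketch below emulates float32 by rounding every intermediate value through struct, then takes a central difference of f(x) = x² at x = 0.5, dividing by the nominal 2·eps as a finite-difference check does. The exact derivative is 1.0. For comparison, torch.autograd.gradcheck defaults to eps=1e-6, atol=1e-5 and rtol=1e-3. This is a toy, not a model of the extension itself.

```python
import struct


def to_f32(v):
    """Round a Python float (an IEEE-754 double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def central_diff_sq(x, eps, single):
    """Central difference of f(x) = x * x, divided by the nominal 2 * eps.
    With single=True every intermediate result is rounded to float32."""
    r = to_f32 if single else (lambda v: v)
    xp, xm = r(x + eps), r(x - eps)
    return (r(xp * xp) - r(xm * xm)) / (2 * eps)


x = 0.5  # the exact derivative of x * x here is 1.0
for eps in (1e-6, 1e-3):
    print(eps, central_diff_sq(x, eps, single=False), central_diff_sq(x, eps, single=True))
```

With eps=1e-6, rounding x ± eps to float32 already changes the actual step by about 1%, so the float32 estimate is off by roughly 1.3%, well beyond rtol=1e-3; with eps=1e-3 the error drops to around 1e-5. The double-precision estimate is accurate for both.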

Thank you for your patient answers :yum:. I finally found the reason: my CUDA backward code had a bug. After changing int i=0 in the outer for loop to int i=index, it passes gradcheck with eps=1e-3 and 1e-4 but fails with 1e-5 or smaller, which I think is a precision problem. (Before the fix, gradcheck failed for every eps from 1e-2 to 1e-7.)

int batchsize = blockIdx.x;
int index = threadIdx.x;
int stride = blockDim.x;
.......
//Wrong: for (int i=0;i<n;i+=stride)
//Correct:
for (int i=index;i<n;i+=stride){
	........
			
	for (int k=0;k<c;k++)
	{
		atomicAdd(grad_xyz1 + i + k*n, g*(xyz1[i + k*n]-xyz2[id + k*m]));
		atomicAdd(grad_xyz2 + id + k*m, -(g*(xyz1[i + k*n]-xyz2[id + k*m])));
	}

}
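A small simulation of the loop indices makes the effect of the bug visible. It assumes blockDim.x = 4: what optimal_num_threads(5) really returns is not shown in the thread, but with that value the buggy loop processes only i = 0 and i = 4, once per thread. That would match the analytical Jacobian in the first post, which is nonzero only for points 0 and 4 and roughly four times the numerical values there. visits is a hypothetical helper.

```python
def visits(n, block_dim, buggy):
    """Count how often each index i in [0, n) is handled when block_dim threads
    each run `for (i = start; i < n; i += stride)` with stride = blockDim.x."""
    counts = [0] * n
    for thread_idx in range(block_dim):      # threadIdx.x
        start = 0 if buggy else thread_idx   # the bug: start ignores threadIdx.x
        for i in range(start, n, block_dim):
            counts[i] += 1                   # one round of atomicAdds per visit
    return counts


print(visits(5, 4, buggy=True))   # [4, 0, 0, 0, 4]
print(visits(5, 4, buggy=False))  # [1, 1, 1, 1, 1]
```

Because every thread repeats the same atomicAdd calls, the gradients of the visited points are multiplied by the number of threads, and all other points get no gradient at all.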

Thank you for your advice! :+1: I will add a template to my code.

Awesome!

Happy that you found the issue.
Given these values, it sounds like the implementation is correct indeed.