CUDA extension gradient does not match autograd results

Hi,

My environment is an Ubuntu 16.04 Docker container with Anaconda Python 3.6.9 and PyTorch 1.3.1. A simplified version of my code is below; it is an implementation of nn.CrossEntropyLoss written as a CUDA extension:

import torch
import torch.nn as nn
import torchvision

import fun_cpp  # the compiled CUDA extension


# autograd Function wrapping the extension's forward/backward kernels
class CrossEntropyFunctionV2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, ignore_index):
        losses = fun_cpp.lsr_forward(logits, labels, ignore_index)

        ctx.variables = logits, labels, ignore_index
        return losses

    @staticmethod
    def backward(ctx, grad_output):
        logits, labels, ignore_index = ctx.variables

        grad = fun_cpp.lsr_backward(grad_output, logits, labels, ignore_index)
        print('grad2: ', grad[0, :, 0, 0])
        return grad, None, None


class CrossEntropyLossV2(nn.Module):

    def __init__(self, reduction='mean', ignore_index=-100):
        super(CrossEntropyLossV2, self).__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        losses = CrossEntropyFunctionV2.apply(
                logits, labels, self.ignore_index)
        #  losses = losses[labels != self.ignore_index]
        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            losses = losses.mean()
        return losses


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        net = torchvision.models.resnet18(pretrained=False)
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.maxpool = net.maxpool
        self.relu = net.relu
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4
        self.out = nn.Conv2d(512, 3, 3, 1, 1)
    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.bn1(feat1)
        feat3 = self.relu(feat2)
        #  feat4 = self.maxpool(feat3)
        feat5 = self.layer1(feat3)
        feat6 = self.layer2(feat5)
        feat7 = self.layer3(feat6)
        feat8 = self.layer4(feat7)
        feat9 = self.out(feat8)

        feat7.retain_grad()
        feat7.register_hook(lambda grad: grad*1000)
        return feat9, feat7

net1 = Model()
net2 = Model()
from copy import deepcopy
net2.load_state_dict(deepcopy(net1.state_dict()))

#  criteria1 = CrossEntropyLossV1(reduction='mean', ignore_index=255)
criteria1 = CrossEntropyLossV2(reduction='mean', ignore_index=255)
criteria2 = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)

net1.cuda()
net2.cuda()
net1.train()
net2.train()
criteria1.cuda()
criteria2.cuda()

optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

bs = 32
for it in range(10):
    inten = torch.randn(bs, 3, 256, 256).cuda()
    lbs = torch.randint(0, 3, (bs, 16, 16)).cuda()
    #  net2.load_state_dict(deepcopy(net1.state_dict()))

    optim1.zero_grad()
    logits, feat = net1(inten.clone())
    loss1 = criteria1(logits, lbs.clone())
    loss1.backward()
    print('feat.grad1', feat.grad[0, :4, 0, 0])
    optim1.step()

    logits, feat = net2(inten.clone())
    optim2.zero_grad()
    loss2 = criteria2(logits, lbs.clone())
    loss2.backward()
    print('feat.grad2', feat.grad[0, :4, 0, 0])
    optim2.step()
    print(loss1.item() - loss2.item())
    print()

And the CUDA implementation is like this:

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cfloat>

#include <iostream>

using std::cout;
using std::endl;

#define BLOCKSIZE 512


// kernel function for forward and backward
template<typename scalar_t>
__global__ void LSRLossForward(const int n_size,
                            const int dimsize, const int m_size,
                            const scalar_t *log_scores,
                            const int64_t *labels,
                            scalar_t *losses,
                            const int64_t ignore_index) {
    // shared memory
    __shared__ scalar_t sdata[BLOCKSIZE + 2];

    int tid = threadIdx.x;
    int bid = blockIdx.x;

    int samplesize = n_size * m_size;
    for (int i{bid}; i < samplesize; i+=gridDim.x) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;
        int64_t lb = labels[i];
        if (lb == ignore_index) {
            if (tid == 0) losses[i] = 0;
            continue;
        } 

        int idx = n_idx * dimsize * m_size + lb * m_size + m_idx;
        if (tid == 0) losses[i] = -log_scores[idx];
    }
}


template<typename scalar_t>
__global__ void LSRLossBackward(const int n_size,
                            const int dimsize, const int m_size,
                            const scalar_t *grad,
                            scalar_t *grad_logits,
                            const scalar_t *scores,
                            const int64_t *labels,
                            const int64_t ignore_index) {
    int tid = threadIdx.x;
    int bid = blockIdx.x;

    int samplesize = n_size * m_size;
    for (int i{bid}; i < samplesize; i+=gridDim.x) {
        int n_idx = i / m_size;
        int m_idx = i % m_size;
        int64_t lb{labels[i]};
        for (int j{tid}; j < dimsize; j+=blockDim.x) {
            int idx = n_idx * dimsize * m_size + j * m_size + m_idx; 
            scalar_t gradval = 0; 
            if (lb != ignore_index) {
                gradval = scores[idx];
                if (j == lb) {
                    gradval -= 1.;
                }
            }
            grad_logits[idx] = gradval * grad[i];
        }
    }
}


// cuda forward and backward
at::Tensor LSR_forward_cuda(const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // CHECK type and shape
    AT_ASSERTM(logits.type().is_cuda(), "logits should be cuda");
    AT_ASSERTM(labels.type().is_cuda(), "labels should be cuda");

    const int n_size = logits.size(0);
    const int dimsize = logits.size(1);
    const int m_size = logits.numel() / (n_size * dimsize);
    const int samplesize = labels.numel();

    // allocate memory and cuda grid/block
    auto losses = torch::zeros_like(labels, logits.options());
    auto log_scores = torch::log_softmax(logits, 1);

    dim3 grid1(std::min(samplesize, (int)4096));
    dim3 block1(std::min(dimsize, (int)BLOCKSIZE));
    if (losses.numel() == 0) {
        THCudaCheck(cudaGetLastError());
        return losses;
    }

    // call kernel
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(losses.scalar_type(), "lsr forward", [&] {
        int shm_size = BLOCKSIZE * sizeof(scalar_t) * 2; 
        LSRLossForward<scalar_t><<<grid1, block1, shm_size, at::cuda::getCurrentCUDAStream()>>>(
            n_size, dimsize, m_size, 
            log_scores.contiguous().data<scalar_t>(), 
            labels.contiguous().data<int64_t>(), 
            losses.contiguous().data<scalar_t>(),
            ignore_index
        );
    });
    THCudaCheck(cudaGetLastError());
    return losses;
}


at::Tensor LSR_backward_cuda(const at::Tensor &grad,
                                  const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // CHECK type and shape
    AT_ASSERTM(grad.type().is_cuda(), "grad should be cuda");
    AT_ASSERTM(logits.type().is_cuda(), "logits should be cuda");
    AT_ASSERTM(labels.type().is_cuda(), "labels should be cuda");

    const int n_size = logits.size(0);
    const int dimsize = logits.size(1);
    const int m_size = logits.numel() / (n_size * dimsize);
    const int samplesize = labels.numel();

    // allocate memory and cuda grid/block
    auto grad_logits = torch::empty_like(logits);
    auto scores = torch::softmax(logits, 1);

    dim3 grid(std::min(samplesize, (int)4096));
    dim3 block(std::min(dimsize, (int)BLOCKSIZE));
    if (grad_logits.numel() == 0) {
        THCudaCheck(cudaGetLastError());
        return grad_logits;
    }

    // call kernel
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_logits.scalar_type(), "lsr backward", [&] {
        int shm_size = BLOCKSIZE * sizeof(scalar_t) * 2; 
        LSRLossBackward<scalar_t><<<grid, block, shm_size, at::cuda::getCurrentCUDAStream()>>>(
            n_size, dimsize, m_size, 
            grad.contiguous().data<scalar_t>(), 
            grad_logits.contiguous().data<scalar_t>(),
            scores.contiguous().data<scalar_t>(), 
            labels.contiguous().data<int64_t>(), 
            ignore_index
        );
    });
    THCudaCheck(cudaGetLastError());
    return grad_logits;
}

// python interface
at::Tensor LSR_forward(const at::Tensor &logits,
                             const at::Tensor &labels,
                             const int64_t ignore_index
                             ) {
    if (!(logits.type().is_cuda() && labels.type().is_cuda())) {
        AT_ERROR("this LSR loss only supports gpu mode\n");
    } 
    at::DeviceGuard guard(logits.device());
    return LSR_forward_cuda(logits, labels, ignore_index);
}

at::Tensor LSR_backward(const at::Tensor &grad,
                                  const at::Tensor &logits,
                                  const at::Tensor &labels,
                                  const int64_t ignore_index) {
    // TODO: try AT_ASSERTM
    if (!(logits.type().is_cuda() && labels.type().is_cuda())) {
        AT_ERROR("this LSR loss only supports gpu mode\n");
    } 
    at::DeviceGuard guard(logits.device());
    return LSR_backward_cuda(grad, logits, labels, ignore_index);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("lsr_forward", &LSR_forward, "lsr forward");
    m.def("lsr_backward", &LSR_backward, "lsr backward");
}

My problem is that the difference in gradient and loss between the CUDA extension implementation and nn.CrossEntropyLoss becomes too big after about 10 iterations, even though the gradient of the CUDA implementation matches nn.CrossEntropyLoss in the first iteration. How can I solve this problem?

Hi,

How large is the difference at the first iteration?

My expectation here is that they compute the same thing, but because of floating point errors, the values they return are slightly different (on the order of 1e-6).
Then when you do gradient descent, you take a step in that slightly different direction, leading to different gradients, etc., and the two models diverge from each other.
This is quite expected: even though they end up with different weight values, they should still converge properly to similar final losses (assuming your model is reasonably stable, like most CNNs).
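
If you want to see how small that per-step difference actually is before the optimizer amplifies it, you can compare the two losses on a single fixed input, with no model and no optimizer in the loop. A rough sketch (it reuses CrossEntropyLossV2 from your post; the expected magnitudes are only ballpark figures):

import torch
import torch.nn.functional as F

# same random logits/labels fed to both losses, nothing else in the graph
logits = torch.randn(32, 3, 16, 16, device='cuda', requires_grad=True)
labels = torch.randint(0, 3, (32, 16, 16), device='cuda')

loss_ext = CrossEntropyLossV2(reduction='mean', ignore_index=255)(logits, labels)
g_ext, = torch.autograd.grad(loss_ext, logits)

loss_ref = F.cross_entropy(logits, labels, reduction='mean', ignore_index=255)
g_ref, = torch.autograd.grad(loss_ref, logits)

# both should be tiny (roughly 1e-6 or below); anything much larger points
# at a real bug rather than floating point noise
print((loss_ext - loss_ref).abs().item())
print((g_ext - g_ref).abs().max().item())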

Hi, Thanks for replying!!!

I am afraid the difference is beyond tolerance after about 7 iterations. Here is the log of the first 10 iterations; the difference is quite small in the first iteration, but it grows rapidly over the following few:

grad2:  tensor([ 6.1093e-05,  3.2695e-05, -9.3787e-05], device='cuda:0')
feat.grad1 tensor([ 2.6212e-06,  1.7502e-06,  1.8138e-06, -2.5879e-06], device='cuda:0')
feat.grad2 tensor([ 2.6212e-06,  1.7503e-06,  1.8139e-06, -2.5879e-06], device='cuda:0')
1.1920928955078125e-07

grad2:  tensor([ 7.4433e-07, -8.7096e-05,  8.6352e-05], device='cuda:0')
feat.grad1 tensor([-1.7985e-06,  1.4837e-06,  1.2604e-06, -9.5811e-07], device='cuda:0')
feat.grad2 tensor([-1.7828e-06,  1.4885e-06,  1.2585e-06, -9.3544e-07], device='cuda:0')
-9.5367431640625e-07

grad2:  tensor([-2.3506e-05,  2.1586e-05,  1.9193e-06], device='cuda:0')
feat.grad1 tensor([ 8.6845e-07, -6.5791e-07,  3.0064e-06,  5.4995e-06], device='cuda:0')
feat.grad2 tensor([ 1.3321e-06, -1.5314e-07,  2.5235e-06,  4.8529e-06], device='cuda:0')
-0.00011467933654785156

grad2:  tensor([ 5.4759e-08,  2.2978e-06, -2.3525e-06], device='cuda:0')
feat.grad1 tensor([-3.3766e-06, -2.7148e-06,  2.5728e-06, -3.7042e-06], device='cuda:0')
feat.grad2 tensor([-2.0076e-06, -1.5157e-06,  3.2242e-06, -3.4447e-06], device='cuda:0')
-0.013294696807861328

grad2:  tensor([ 5.9485e-07, -5.9485e-07,  9.8629e-14], device='cuda:0')
feat.grad1 tensor([-5.6624e-06, -8.5295e-08,  5.8959e-07, -1.6566e-06], device='cuda:0')
feat.grad2 tensor([-3.1079e-06,  3.5109e-06,  2.4890e-06, -2.3740e-06], device='cuda:0')
-0.04217672348022461

grad2:  tensor([ 1.2207e-04, -1.2207e-04,  1.4023e-10], device='cuda:0')
feat.grad1 tensor([-2.6114e-06, -2.7196e-06,  2.5282e-06, -1.0882e-06], device='cuda:0')
feat.grad2 tensor([-3.4735e-06,  2.6183e-06,  3.2230e-06, -1.0301e-06], device='cuda:0')
-0.02991771697998047

grad2:  tensor([-1.2201e-04,  9.0557e-05,  3.1458e-05], device='cuda:0')
feat.grad1 tensor([ 7.7824e-07,  1.6687e-06, -1.0701e-06,  1.8441e-06], device='cuda:0')
feat.grad2 tensor([ 7.6660e-07, -5.4422e-08, -3.0931e-06, -1.6282e-06], device='cuda:0')
0.05857133865356445

grad2:  tensor([-2.1591e-05,  1.8860e-05,  2.7316e-06], device='cuda:0')
feat.grad1 tensor([-1.5478e-06, -1.0981e-06,  2.9503e-06, -7.7435e-07], device='cuda:0')
feat.grad2 tensor([ 1.4916e-06, -1.2717e-06, -1.8291e-07,  4.2905e-06], device='cuda:0')
-0.11451804637908936

grad2:  tensor([ 5.3404e-10,  1.0895e-04, -1.0895e-04], device='cuda:0')
feat.grad1 tensor([-2.5790e-06, -6.7249e-08, -2.2561e-06, -2.0569e-06], device='cuda:0')
feat.grad2 tensor([-2.1519e-06,  8.4891e-07,  8.0614e-07, -1.6089e-06], device='cuda:0')
0.4443323612213135

grad2:  tensor([ 7.9950e-07, -1.2178e-04,  1.2098e-04], device='cuda:0')
feat.grad1 tensor([ 3.2793e-06,  5.9981e-07, -4.5706e-07, -3.7950e-06], device='cuda:0')
feat.grad2 tensor([-1.7417e-06, -1.5725e-06,  1.4380e-07, -1.4676e-06], device='cuda:0')
0.5565323829650879

Also, by running cuda-memcheck, I got many errors, but I could not find any explicit error after looking very carefully at the logic of my code. Would you please tell me what my problem is?

Hi,

So yes, this is expected. The very small difference at each iteration is amplified by gradient descent, and the two versions will converge to different sets of weights.
You can see this as using different random seeds for your initialization: you get a different set of weights, but the final performance will be very close.

Also, by running cuda-memcheck, I got many errors, but I could not find any explicit error after looking very carefully at the logic of my code. Would you please tell me what my problem is?

It is hard to say. I am by no means a CUDA expert. What kind of errors do you see? Do they point to your code directly?

But the difference in the loss grows to as large as 0.5 after merely 10 iterations.

Actually, this implementation is a simplified version of my code. In my original code, the loss difference can grow to more than 1 in merely 2 iterations. Besides, I noticed that when I run both the autograd and the CUDA implementations, the CUDA implementation shows a large bias, but if I run only the CUDA implementation, the loss value and gradient are quite close to the output of the autograd implementation. What is likely to be the problem?

As for the output of cuda-memcheck, each piece of the error message looks similar, like this:

========= Program hit cudaErrorCudartUnloading (error 4) due to "driver shutting down" on CUDA API call to cudaFree. 
=========     Saved host backtrace up to driver entry point at error
=========     Host Frame:/usr/lib/x86_64-linux-gnu/libcuda.so.1 [0x391b13]
=========     Host Frame:/miniconda/envs/py36/lib/python3.6/site-packages/torch/lib/../../../.././libcublasLt.so.10 [0x2a91e6]
=========     Host Frame:/miniconda/envs/py36/lib/python3.6/site-packages/torch/lib/../../../../libcublas.so.10 (cublasDestroy_v2 + 0xed) [0xb991d]
=========     Host Frame:/miniconda/envs/py36/lib/python3.6/site-packages/torch/lib/libtorch.so (cudnnDestroy + 0x219) [0x70fbc99]
=========     Host Frame:/miniconda/envs/py36/lib/python3.6/site-packages/torch/lib/libtorch.so [0x3ec9bad]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 [0x39ff8]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 [0x3a045]
=========     Host Frame:/lib/x86_64-linux-gnu/libc.so.6 (__libc_start_main + 0xf7) [0x20837]
=========     Host Frame:python [0x1c3160]

Any suggestions, please?

cuda-memcheck is not detecting any issues in your code; it is crashing instead.
Try installing the latest nightly binary, as an invalid handle destruction was fixed recently (which seems to be causing this error).

Thanks. Would you please give me some suggestions on why the CUDA extension implementation of cross entropy differs from the PyTorch implementation by such a big margin within so few iterations?

If you want to double-check your implementation, you can use torch.autograd.gradcheck() with your function and double-precision inputs. See the docs for more details.
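
For reference, a minimal gradcheck sketch could look like this (it assumes your extension also compiles the double-precision path, which AT_DISPATCH_FLOATING_TYPES_AND_HALF includes; the shapes are arbitrary small values):

import torch

# gradcheck compares the analytical backward of the custom Function against
# numerical finite differences, so it needs small float64 inputs
logits = torch.randn(2, 3, 4, 4, dtype=torch.double, device='cuda', requires_grad=True)
labels = torch.randint(0, 3, (2, 4, 4), device='cuda')

ok = torch.autograd.gradcheck(
    lambda x: CrossEntropyFunctionV2.apply(x, labels, 255),
    (logits,), eps=1e-6, atol=1e-4)
print('gradcheck passed:', ok)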

But I will restate what I said above: a 1-bit difference at any stage (due to floating point non-associativity, for example) will be amplified by gradient descent and will lead to completely different final weights (even after a small number of iterations if you have a fairly deep net).
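
To make that concrete, here is a tiny, self-contained illustration of float32 non-associativity (unrelated to your kernel); a parallel reduction sums in a different order than a sequential one, and those last-bit differences are exactly what gradient descent then amplifies:

import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

print((a + b) + c)  # tensor(1.)
print(a + (b + c))  # tensor(0.) -- the 1.0 is absorbed next to the 1e8 term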

Thanks for replying !!

Can autograd.gradcheck() detect memory corruption? One of my problems is that the result is normal when I use only my CUDA implementation. But when I run both my CUDA implementation and PyTorch autograd to check the difference, in the CUDA implementation branch the backward gradients of the layers that use PyTorch's built-in autograd become very large and finally turn into nan after 5 iterations, even though my CUDA implementation produces identical backward gradients when the input logits and labels are the same. Thus I suspect that there might be some memory corruption that changes the values of other layers' parameters.
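
One way to test that hypothesis (just a sketch, reusing the names from the script above): snapshot the parameter values before the backward pass and verify they are unchanged afterwards, since a backward pass should only write to .grad, never to the parameters themselves:

# snapshot all parameter values before running the custom backward
before = {name: p.detach().clone() for name, p in net1.named_parameters()}

logits, feat = net1(inten)
loss = criteria1(logits, lbs)
loss.backward()

# any mismatch here would hint at out-of-bounds writes from the kernel
# rather than an autograd problem
for name, p in net1.named_parameters():
    if not torch.equal(before[name], p.detach()):
        print('parameter changed during backward:', name)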

No, I'm afraid you're going to have to use CUDA-specific tools to detect this kind of issue. PyTorch has no insight into your kernel.