PyTorch with CUDA/C++: Initializing a new tensor is too slow, or RuntimeError: CUDA error: invalid configuration argument

Hi, I have spent three days on this bug :cry: and I can't figure it out; I desperately need your help.
I wrote C++/CUDA code for the PyTorch extension framework and ran into this problem. Briefly speaking, when my code goes through the forward pass successfully, initializing a new tensor afterwards is very slow (the forward pass takes about 0.0x seconds, while the initialization takes around 40 seconds, which slows down my whole training considerably).
Strangely, when I change BLOCK_SIZE, a constant used in the CUDA kernel launch, from 32 to 128 (generally, from a small number to a bigger one), the initialization becomes much faster. However, when I then check the output tensor from the forward pass, this error is thrown:

  File "benchmark.py", line 50, in <module>
    print(c[0,0,:2,:2])
  File "/home/haolin/.conda/envs/gcb/lib/python3.7/site-packages/torch/tensor.py", line 82, in __repr__
    return torch._tensor_str._str(self)
  File "/home/haolin/.conda/envs/gcb/lib/python3.7/site-packages/torch/_tensor_str.py", line 300, in _str
    tensor_str = _tensor_str(self, indent)
  File "/home/haolin/.conda/envs/gcb/lib/python3.7/site-packages/torch/_tensor_str.py", line 201, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/home/haolin/.conda/envs/gcb/lib/python3.7/site-packages/torch/_tensor_str.py", line 79, in __init__
    tensor_view = tensor.reshape(-1)
RuntimeError: CUDA error: invalid configuration argument

Part of benchmark.py, used as the test code, is shown below:

    B = 20
    C = 16
    W = 336
    H = 336

    x = list(range(0, W * H)) * B * C
    x = torch.tensor(x, dtype=torch.float64, requires_grad=True).cuda().reshape(B,C,W,H)

    forward_time = 0
    backward_time = 0
    for _ in tqdm(range(options.runs)):

        # if options.example == 'cuda':
        #     globalContrast.zero_grad()

        start = time.time()
        c = globalContrast(x)
        print(c.shape)
        print(c[0,0,:2,:2])
        elapsed_fw = time.time() - start
        forward_time += elapsed_fw
        
        # this line is very slow with BLOCK_SIZE=32, but much faster with BLOCK_SIZE=64 or higher
        x = torch.rand((B, C, W, H), requires_grad=True).cuda()

        # print(c.shape)        
        # with BLOCK_SIZE=64, this line will cause the error mentioned above
        print(c[0,0,:2,:2])   # just print sth out for a test

        # ... for backward

The forward pass in global_contrast_kernel.cu:


namespace global_contrast_kernel{

template <typename scalar_t> 
__global__ void forward(
    const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> feature,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> output
) {
    long col = threadIdx.x + blockIdx.x * blockDim.x;
    long row = threadIdx.y + blockIdx.y * blockDim.y;
    const long B = feature.size(0);
    const long C = feature.size(1);
    const long W = feature.size(2);
    const long H = feature.size(3);
    scalar_t dis = 0.0f;
    for (auto i=0 ; i<B ; i++){
        for (auto j=0 ; j<C ; j++){
            for (auto _w=0 ; _w<W ; _w++){
                for (auto _h=0 ; _h<H ; _h++){
                    scalar_t diff = feature[i][j][col][row] - feature[i][j][_w][_h];
                    dis += diff * diff;
                }
            }
        }
        output[i][0][col][row] = dis;
        dis = 0.0f;
    }

    __syncthreads();
}

}

torch::Tensor global_contrast_cuda_forward(
    const torch::Tensor& feature
) {

    cudaSetDevice(feature.get_device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const long B = feature.size(0);
    const long C = feature.size(1);
    const long W = feature.size(2);
    const long H = feature.size(3);

    // allocate output tensor
    auto output = torch::zeros({B, 1, W, H}, feature.options());

    dim3 blockSize(BLOCK_SIZE, BLOCK_SIZE);
    dim3 gridSize((W + blockSize.x - 1) / blockSize.x, 
        (H + blockSize.y - 1) / blockSize.y);

    AT_DISPATCH_FLOATING_TYPES(feature.type(), "global_contrast_cuda_forward", ([&]{
        global_contrast_kernel::forward <scalar_t><<< gridSize, blockSize, 0, stream>>>(
            feature.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
            output.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>()
        );
    }));

    cudaDeviceSynchronize();
    return output;
}

The entire repo is released on my GitHub: Global_contrast_CUDA.
I sincerely appreciate your help in advance. :mask:

  • The hardware has a limit of 1024 on the total size of a block, i.e., the number of threads run in parallel on a multiprocessor. The product of the block dimensions cannot exceed that.
  • BLOCK_SIZE isn’t used in the copying; your measurement seems to be an artifact of not cuda.synchronize()ing (see the sketch after this list).
  • Your kernel is terribly inefficient, with a number of memory accesses that is quadratic in the image size. You should check out one of the matrix multiplication CUDA tutorials; the block partitioning scheme applies here, too.
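
To illustrate the first two points, here is a quick sketch; gpu_op is just a stand-in for your extension's forward, not the actual globalContrast:

    import time
    import torch

    # Launch configuration: a block of dim3 blockSize(BLOCK_SIZE, BLOCK_SIZE) may hold
    # at most 1024 threads in total.
    for block_size in (32, 64, 128):
        threads = block_size * block_size
        print(f"BLOCK_SIZE={block_size}: {threads} threads/block ->",
              "ok" if threads <= 1024 else "invalid configuration argument")

    # Timing: without torch.cuda.synchronize(), time.time() only measures the kernel
    # launch; the real cost then surfaces at whatever next forces a synchronization
    # (e.g. an allocation or printing the result), which is why the "initialization"
    # looks slow.
    if torch.cuda.is_available():
        x = torch.rand(1024, 1024, device="cuda")
        gpu_op = lambda t: t @ t               # stand-in for globalContrast(x)
        torch.cuda.synchronize()
        start = time.time()
        c = gpu_op(x)
        torch.cuda.synchronize()               # wait for the kernel to actually finish
        print("forward:", time.time() - start)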

Best regards

Thomas

Oh, thank you very much! I see now that the total number of threads per block must not exceed 1024, and I found that my kernel accesses global memory far too many times, which is terribly inefficient!
I have optimized my code to a pleasant speed~
Thank you again!

Hi Thomas, I still have a problem with input.grad: it is always None. Do you know what is wrong?

Here is my class .py file:

class GlobalContrastFunction(Function):

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x) # save x for backward
        return global_contrast.forward(x)

    @staticmethod
    def backward(ctx, grad):
        x = ctx.saved_variables # load
        dx = global_contrast.backward(grad.contiguous(), *x) # get gradient 
        print(dx.data[0,0,:4,:4])  # print out shows that dx has correct value!
        return dx


class GlobalContrast(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return GlobalContrastFunction.apply(x)

while in the loop of the main .py file:

    globalContrast = GlobalContrast().cuda()
    x = torch.rand((B, C, W, H), requires_grad=True).cuda()

    for _ in tqdm(range(options.runs)):

        if options.example == 'cuda':
            globalContrast.zero_grad()

        c = globalContrast(x) # forward

        grad = torch.sum(c)
        grad.backward(retain_graph=True) # set retain_graph=True here to save gradients for each node

        print(c.grad) # Here is None!
        print(x.grad) # Here is also None!

But when you call global_contrast.backward manually, you do get something non-zero?
Some other random thoughts on your code:

  • .data should not be needed for anything these days.
  • You would not get gradients for non-leaves (like c) unless you call c.retain_grad() before the backward; see the sketch after this list. retain_graph is for when you want to call .backward a second time, not for keeping intermediate gradients.
  • Calling the variable grad seems, let's say, unconventional, but I know it's just an example.
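
Here is a minimal sketch of the leaf / non-leaf distinction, using a simple stand-in op rather than your actual extension:

    import torch

    x = torch.rand(2, 3, requires_grad=True)   # a leaf: backward() populates x.grad
    c = x * x                                   # a non-leaf: c.grad stays None by default
    c.retain_grad()                             # ask autograd to keep the intermediate gradient

    torch.sum(c).backward()

    print(x.grad)   # populated, because x is a leaf created with requires_grad=True
    print(c.grad)   # populated only because of retain_grad()

Note that in your loop, x = torch.rand(..., requires_grad=True).cuda() is itself a non-leaf (the leaf is the CPU tensor before the .cuda() call), which is why x.grad stays None as well.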

Best regards

Thomas

Oh, I get it. I converted x to an nn.Parameter, and now x.grad holds the gradient:

x = torch.rand((B, C, W, H), requires_grad=True)
x = nn.Parameter(x) # x should be the parameter
...
c = globalContrast(x)
y = torch.sum(c)
y.backward() # backward populates x.grad
print(x.grad) # then get sth instead of None!

I also found that your suggestion of x.retain_grad() works even better, thanks!
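
A minimal sketch of that retain_grad() variant (with a simple stand-in op rather than the actual extension):

    import torch

    x = torch.rand(2, 3, requires_grad=True).cuda()   # .cuda() makes this x a non-leaf
    x.retain_grad()                                    # so explicitly keep its gradient
    c = x * x                                          # stand-in for globalContrast(x)
    torch.sum(c).backward()
    print(x.grad)                                      # populated instead of None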