What's the difference between .cuda() and .to(device)

What’s the difference between tensor.cuda() and tensor.to(0)?

I copied the function CUDA_tensor_apply2 from ATen/cuda/CUDAApplyUtils.cuh and use it in a PyTorch extension.

When I run

import torch
from my_extension import run

x = torch.rand(3, 4)
y = x.cuda()
print(run(y))  # all is well
print(y)  # all is well
print(x)  # all is well

But if I run

import torch
from my_extension import run

x = torch.rand(3, 4)
y = x.to(0)
print(run(y))  # incorrect result
print(y)  # RuntimeError (700)
print(x)  # no problem

z = x.cuda()  # RuntimeError (700)

I got the error RuntimeError (700): an illegal memory access was encountered at /pytorch/aten/src/THC/THCReduceAll.cuh:327

Why? Is this error related to at::cuda::current_device()?

Both calls should just transfer the tensor to the cuda:0 device.
Do you have a small reproducible code snippet, so that we can have a look?
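
As a quick sanity check, a small libtorch sketch (assuming a single-GPU setup where the default CUDA device is cuda:0) should show both calls landing on the same device:

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor x = torch::rand({3, 4});
  torch::Tensor a = x.cuda();                               // moves to the current CUDA device
  torch::Tensor b = x.to(torch::Device(torch::kCUDA, 0));   // moves to cuda:0 explicitly
  std::cout << a.device() << " " << b.device() << std::endl; // expect: cuda:0 cuda:0
  return 0;
}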

@ptrblck
Thank you for replying, and I am sorry that my description above was not accurate.
What I mean is: my code runs correctly when using .cuda(), but I get the error when using .to(1).

Here’s my code, which consists of 3 files: my_extension.cpp, my_extension_kernel.cu, and setup.py.
File: my_extension.cpp

// file my_extension.cpp
#include <torch/torch.h>
#include <cmath>

void run_cpu(at::Tensor& out, const at::Tensor& x);
void run_gpu(at::Tensor& out, const at::Tensor& x);

void run_cpu(at::Tensor& out, const at::Tensor& x) {
    out = x + 1;
}

#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_INPUT(x) CHECK_CONTIGUOUS(x)

at::Tensor run(const at::Tensor x) {
    CHECK_INPUT(x);
    at::Tensor out = at::empty(x.sizes(), x.options());
    if(x.type().is_cuda()) {
        run_gpu(out, x);
    } else {
        run_cpu(out, x);
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("run", &run, "My Extension");
}

File: my_extension_kernel.cu

// file my_extension_kernel.cu
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/generic/THCTensorCopy.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
#include <cmath>


// Applies `op` elementwise to tensors `a` and `b` on the GPU; the pointwise kernel
// is launched on the current CUDA device and its current stream (curDevice below).
template <typename scalar1, typename scalar2, int step, typename Op>
inline bool CUDA_tensor_apply22(at::Tensor a,
                               at::Tensor b,
                               const Op op,
                               at::cuda::TensorArgType aType = at::cuda::TensorArgType::ReadWrite,
                               at::cuda::TensorArgType bType = at::cuda::TensorArgType::ReadOnly) {
    int64_t totalElements = a.numel();

    if (a.numel() == 0) { return true; }
    constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
    const dim3 block = dim3(AT_APPLY_THREADS_PER_BLOCK);
 
    dim3 grid;
    int64_t curDevice = at::cuda::current_device();    

    if (curDevice == -1) return false;
    if (!at::cuda::getApplyGrid<step>(totalElements, grid, curDevice)) {
      return false;
    }

#define HANDLE_CASE(TYPE, A, B)                                        \
  at::cuda::kernelPointwiseApply2<Op,                                  \
                        scalar1,                                       \
                        scalar2,                                       \
                        TYPE, A, B, step>                              \
   <<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>(    \
       aInfo, bInfo, static_cast<TYPE>(totalElements), op);

#define HANDLE_B_CASE(TYPE, A, B) {         \
  switch (B) {                              \
    case 1:                                 \
      HANDLE_CASE(TYPE, A, 1);              \
      break;                                \
    case 2:                                 \
      HANDLE_CASE(TYPE, A, 2);              \
      break;                                \
    default:                                \
      HANDLE_CASE(TYPE, A, -1);             \
      break;                                \
  }                                         \
}

#define HANDLE_A_CASE(TYPE, A, B) {         \
  switch (A) {                              \
    case 1:                                 \
      HANDLE_B_CASE(TYPE, 1, B);            \
      break;                                \
    case 2:                                 \
      HANDLE_B_CASE(TYPE, 2, B);            \
      break;                                \
    default:                                \
      HANDLE_B_CASE(TYPE, -1, B);           \
      break;                                \
  }                                         \
}

    if (at::cuda::detail::canUse32BitIndexMath(a) &&
        at::cuda::detail::canUse32BitIndexMath(b)) {
      at::cuda::detail::TensorInfo<scalar1, unsigned int> aInfo =
        at::cuda::detail::getTensorInfo<scalar1, unsigned int>(a);

      at::cuda::detail::TensorInfo<scalar2, unsigned int> bInfo =
        at::cuda::detail::getTensorInfo<scalar2, unsigned int>(b);
      at::cuda::rearrangeDims(&aInfo, &bInfo);
      aInfo.collapseDims();
      bInfo.collapseDims();

      HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
    } else {
      at::cuda::detail::TensorInfo<scalar1, uint64_t> aInfo =
        at::cuda::detail::getTensorInfo<scalar1, uint64_t>(a);

      at::cuda::detail::TensorInfo<scalar2, uint64_t> bInfo =
        at::cuda::detail::getTensorInfo<scalar2, uint64_t>(b);
      at::cuda::rearrangeDims(&aInfo, &bInfo);
      aInfo.collapseDims();
      bInfo.collapseDims();

      if (aInfo.dims == 1 && bInfo.dims == 1) {
        HANDLE_CASE(uint64_t, 1, 1);
      } else {
        HANDLE_CASE(uint64_t, -1, -1);
      }
    }
#undef HANDLE_CASE
#undef HANDLE_B_CASE
#undef HANDLE_A_CASE

  return true;
}


template <typename scalar1, typename scalar2, typename Op>
inline bool CUDA_tensor_apply22(at::Tensor a,
                               at::Tensor b,
                               const Op op,
                               at::cuda::TensorArgType aType = at::cuda::TensorArgType::ReadWrite,
                               at::cuda::TensorArgType bType = at::cuda::TensorArgType::ReadOnly) {
  return CUDA_tensor_apply22<scalar1, scalar2, 1, Op>(a, b, op, aType, bType);
}


// GPU path: computes out[i] = x[i] + 1 elementwise via the apply helper above
// (x is passed as `a` and read, out is passed as `b` and written in the lambda).
void run_gpu(at::Tensor& out, const at::Tensor& x) {
    CUDA_tensor_apply22<float, float>(
        x, out, [=] __device__(const float& val, float& res) {
            res = val + 1;
        }
    );
}

File: setup.py

# file setup.py
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension


setup(
  name="my_extension",
  ext_modules=[
    CUDAExtension("my_extension", 
                  ["my_extension.cpp", "my_extension_kernel.cu"],
                  extra_compile_args={"cxx": [], "nvcc": ["--expt-extended-lambda"]})
    ],
  cmdclass={"build_ext": BuildExtension}    
)

Then run python setup.py install, and run the test as follows:

import torch
import my_extension

x = torch.rand(3, 4)
y = x.cuda()
print(my_extension.run(y))
print(y)

z = x.to(1)
print(my_extension.run(z))
print(z) 

I did a simple check: the function CUDA_tensor_apply22 in my_extension_kernel.cu returns true.

Could you try to get the current device from the passed tensors instead of

int64_t curDevice = at::cuda::current_device();  

I haven’t tested the code yet, but if I’m not mistaken, this would use the current device specified by a device guard.
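
Something along these lines (a rough, untested sketch) inside CUDA_tensor_apply22 is what I have in mind, i.e. take the device from the tensor itself and make it current for the kernel launch:

#include <ATen/cuda/CUDAGuard.h>

// Rough, untested sketch: derive the device from the input tensor and guard it,
// so the kernel is launched on the device that actually holds `a`.
const at::cuda::OptionalCUDAGuard device_guard(at::device_of(a));
int64_t curDevice = a.get_device();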

@ptrblck I just changed that line to

int64_t curDevice = at::device_of(a).value().index();
int64_t bDevice = at::device_of(b).value().index();

and added

std::cout << curDevice << bDevice << std::endl;

before #define HANDLE_CASE.

Then I ran this test code:

x = torch.rand(3, 4)
y = x.cuda()
z = x.to(1)
r = my_extension.run(y)  # correct result and prints 0 0
k = my_extension.run(z)  # incorrect result and prints 1 1
print(y)  # RuntimeError

PS: If I print y without running my_extension.run(z), there is no problem. my_extension.run(z) always returns a tensor whose elements are all 0.

PS: My function CUDA_tensor_apply22 is modified from CUDA_tensor_apply2 in the PyTorch source file include/ATen/cuda/CUDAApplyUtils.cuh.