I was wrong: a device guard is needed for custom CUDA extensions, because the kernels inside the extension are otherwise launched on whatever CUDA device is current rather than on the device that holds the input tensors. The at::cuda::OptionalCUDAGuard fixes this by switching the current device to the device of the given tensor for the duration of the call and restoring it afterwards.
Here is a diff that makes the example runnable on a non-default device:
diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp
index 2434776..62c9628 100644
--- a/cuda/lltm_cuda.cpp
+++ b/cuda/lltm_cuda.cpp
@@ -1,5 +1,5 @@
#include <torch/extension.h>
-
+#include <c10/cuda/CUDAGuard.h>
#include <vector>
// CUDA forward declarations
@@ -40,7 +40,7 @@ std::vector<torch::Tensor> lltm_forward(
CHECK_INPUT(bias);
CHECK_INPUT(old_h);
CHECK_INPUT(old_cell);
-
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}
@@ -62,6 +62,7 @@ std::vector<torch::Tensor> lltm_backward(
CHECK_INPUT(X);
CHECK_INPUT(gate_weights);
CHECK_INPUT(weights);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(weights));
return lltm_cuda_backward(
grad_h,
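For reference, here is a minimal sketch of how the patched extension can be exercised on a non-default device. It assumes the extension from the tutorial is compiled under the module name lltm_cuda and exposes the forward binding with the tutorial's signature; the module name, tensor shapes, and device index are illustrative, not verified against your build.

import torch
import lltm_cuda  # assumed module name for the compiled extension from the tutorial

# Any non-default device reproduces the problem; "cuda:1" requires a second GPU.
device = torch.device("cuda:1")
batch_size, input_features, state_size = 16, 32, 128

# Shapes follow the LLTM tutorial and are only illustrative here.
X = torch.randn(batch_size, input_features, device=device)
old_h = torch.randn(batch_size, state_size, device=device)
old_cell = torch.randn(batch_size, state_size, device=device)
weights = torch.randn(3 * state_size, input_features + state_size, device=device)
bias = torch.randn(3 * state_size, device=device)

# With the OptionalCUDAGuard in the binding, the kernels run on cuda:1 even
# though the current device is still cuda:0 at this point; without the guard,
# the launch targets the wrong device and the call fails or misbehaves.
outputs = lltm_cuda.forward(X, weights, bias, old_h, old_cell)
print(outputs[0].device)  # cuda:1

Without the C++ change, the same effect can be approximated on the Python side by wrapping each call in with torch.cuda.device(X.device): ..., but putting the guard inside the binding keeps callers device-agnostic.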