Error: ‘struct DLTensor’ has no member named ‘ctx’

I am trying to compile a custom C++ extension, but I am getting the following error:

/workspace/cpp_build/cpp_functions.cpp:81:28: error: ‘struct DLTensor’ has no member named ‘ctx’
   81 |       dlMTensor->dl_tensor.ctx.device_id = device_id;

Code related to this is:

#include <torch/extension.h>

#include <vector>
#include <iostream>
#include <sstream>

#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>

at::Tensor backward_weight(
    c10::ArrayRef<long int> weight_size,
    const at::Tensor& grad_output,
    const at::Tensor& input,
    c10::ArrayRef<long int> padding,
    c10::ArrayRef<long int> stride,
    c10::ArrayRef<long int> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

  return at::cudnn_convolution_backward_weight(
      weight_size,
      grad_output,
      input,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic);
}
at::Tensor backward_input(
    c10::ArrayRef<long int> input_size,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    c10::ArrayRef<long int> padding,
    c10::ArrayRef<long int> stride,
    c10::ArrayRef<long int> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

  return at::cudnn_convolution_backward_input(
      input_size,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic);
}

// From pytorch/torch/csrc/Module.cpp
void DLPack_Capsule_Destructor(PyObject* data) {
  HANDLE_TH_ERRORS
    DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor");
  if (dlMTensor) {
    // the dlMTensor has not been consumed, call deleter ourselves
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor));
  } else {
    // the dlMTensor has been consumed
    // PyCapsule_GetPointer has set an error indicator
    PyErr_Clear();
  }
  END_HANDLE_TH_ERRORS_RET()
}

namespace py = pybind11;
using namespace pybind11::literals;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("backward", &backward_weight, "Conv backward_weight cudnn");
  m.def("backward_input", &backward_input, "Conv backward_input cudnn");
  m.def("to_dlpack_with_device_id", [](const at::Tensor& data, int64_t device_id) {
      DLManagedTensor* dlMTensor = at::toDLPack(data);
      
      dlMTensor->dl_tensor.ctx.device_id = device_id;
      auto capsule = py::capsule(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
      return capsule;
  }, "Specify device_id in dlpack, for cupy to copy to right GPU");
}

I want to build this C++ extension against the latest CUDA toolkit and the latest PyTorch version.
Can anyone guide me on how to modify dlMTensor->dl_tensor.ctx.device_id = device_id; for the latest PyTorch version?

I don’t know which DLPack version you are using, but the missing ctx member might have been caused by this PR, which replaced DLContext ctx with DLDevice device.


Thank you @ptrblck for replying. I was able to solve it too. But I wanted to ask one more thing: the backward_weight function returns cudnn_convolution_backward_weight, which is deprecated in the latest version of PyTorch. I checked and found that PyTorch has now moved the same function into the native library, still under the name cudnn_convolution_backward_weight. Can you please guide me on how to use this function with the latest PyTorch version?

I looked at some relevant tutorials from REGISTERING A DISPATCHED OPERATOR IN C++ and EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS.

A similar issue, with no solution so far: How to include native functions from ATen in a cpp/CUDA extension?

In Python you could use:

torch.nn.grad.conv2d_weight
torch.nn.grad.conv2d_input

and I think in libtorch

std::tuple<Tensor, Tensor, Tensor> convolution_backward(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const at::OptionalIntArrayRef bias_sizes_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool, 3> output_mask);

should work, which will return:

std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
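For the Python route mentioned above, a minimal sketch of torch.nn.grad.conv2d_weight and conv2d_input is shown below; the tensor shapes are illustrative assumptions, not from the thread (padding=1 with a 3x3 kernel keeps the spatial size unchanged):

```python
import torch

# Illustrative shapes (assumptions): batch 2, 3 input channels,
# 4 output channels, 3x3 kernel, 8x8 spatial size.
input = torch.randn(2, 3, 8, 8, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
grad_output = torch.randn(2, 4, 8, 8)  # padding=1 keeps 8x8 spatial size

# Gradient of the conv output w.r.t. the weight and the input.
grad_weight = torch.nn.grad.conv2d_weight(
    input, weight.shape, grad_output, stride=1, padding=1)
grad_input = torch.nn.grad.conv2d_input(
    input.shape, weight, grad_output, stride=1, padding=1)

print(grad_weight.shape)  # matches weight: torch.Size([4, 3, 3, 3])
print(grad_input.shape)   # matches input: torch.Size([2, 3, 8, 8])
```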