Does `torch::nn::Conv2d` support tensors in NHWC (channels-last) layout in libtorch?

// ChannelsLast changes only the physical memory layout (strides); the logical
// shape must stay NCHW ({1, 3, 12, 12}), which is what Conv2d indexes by.
// Permuting to {1, 12, 12, 3} first would make dim 1 (12) the channel dim.
torch::Tensor input = torch::randn({ 1, 3, 12, 12 }, torch::kF32);
input = input.contiguous(torch::MemoryFormat::ChannelsLast);
How do I configure `torch::nn::Conv2d` so that it accepts an input in NHWC (channels-last) layout?

Call the same .to(...) method on the module itself.

@ptrblck
`torch::nn::Module` has no corresponding `.to(torch::MemoryFormat::ChannelsLast)` overload.

Oh, you are right. In this case, transform the weight directly:

#include <torch/torch.h>
#include <iostream>

/// Runs torch::nn::Conv2d on a channels-last (NHWC-strided) CUDA input.
///
/// libtorch's Module has no `to(MemoryFormat)` overload, so the conv weight is
/// converted by hand. `set_data` is used rather than reassigning `conv->weight`:
/// assignment would only rebind the public member, leaving the tensor registered
/// in `conv->parameters()` (what optimizers and `Module::to` operate on) as the
/// stale NCHW weight, and a tensor created under NoGradGuard would also no longer
/// require grad.
int main() {
  // Logical shape stays {1, 3, 224, 224}; only the strides are reordered to NHWC.
  torch::Tensor tensor = torch::randn({1, 3, 224, 224})
                             .contiguous(torch::MemoryFormat::ChannelsLast)
                             .to(torch::kCUDA);
  std::cout << tensor.strides() << ", " << tensor.sizes() << '\n';

  auto conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 8, 3).stride(1).bias(false));
  conv->to(torch::kCUDA);
  {
    // Keep the layout conversion itself out of the autograd graph.
    torch::NoGradGuard no_grad;
    // Swap the storage in place so conv->weight and the registered parameter
    // remain the same tensor (same identity, requires_grad preserved).
    conv->weight.set_data(conv->weight.contiguous(torch::MemoryFormat::ChannelsLast));
  }

  std::cout << conv->weight.strides() << ", " << conv->weight.sizes() << '\n';
  auto out = conv->forward(tensor);
  std::cout << out.strides() << ", " << out.sizes() << '\n';
}
nsys nvprof ./main
...
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     85.7            23104          1   23104.0   23104.0     23104     23104          0.0  void convolve_common_engine_float_NHWC<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int…
     14.3             3841          1    3841.0    3841.0      3841      3841          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…

@ptrblck Thank you so much :smiley: