// Asker's attempt at producing an NHWC input.
// NOTE(review): permute({0,2,3,1}) changes the *logical* dim order to
// {1, 12, 12, 3}, and the subsequent .to(ChannelsLast) then re-strides
// that shape as if it were still NCHW — the two steps appear redundant/
// conflicting. For channels-last you normally keep the logical NCHW
// shape and only convert the memory format — TODO confirm intent.
torch::Tensor input = torch::randn({ 1, 3, 12, 12 }, torch::kF32);
input = input.permute({ 0, 2, 3, 1 });
input = input.to(torch::MemoryFormat::ChannelsLast);
How do I configure a torch::nn::Conv2d so that it works with an input in the NHWC (channels-last) layout?
Call the same .to(...)
method on the module itself.
@ptrblck
torch::nn::Module has no corresponding interface: there is no .to(torch::MemoryFormat::ChannelsLast) overload on the module.
Oh, you are right. In that case, convert the weight tensor
directly:
#include <torch/torch.h>
#include <iostream>
int main() {
  // Demo: run a Conv2d with input and weight in channels-last (NHWC)
  // memory format on CUDA, and print strides/sizes at each step so the
  // layout conversion is visible.
  //
  // Guard against machines without a CUDA device so the demo exits with
  // a message instead of aborting inside .to(torch::kCUDA) with an
  // uncaught c10::Error.
  if (!torch::cuda::is_available()) {
    std::cerr << "CUDA is not available; this NHWC demo requires a GPU." << std::endl;
    return 1;
  }

  // Keep the logical NCHW shape {1, 3, 224, 224}; ChannelsLast only
  // changes the strides to the physical N,H,W,C order. No permute()
  // is needed.
  torch::Tensor tensor = torch::randn({1, 3, 224, 224})
                             .to(torch::MemoryFormat::ChannelsLast)
                             .to(torch::kCUDA);
  std::cout << tensor.strides() << ", " << tensor.sizes() << std::endl;

  auto conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 8, 3).stride(1).bias(false));
  conv->to(torch::kCUDA);

  {
    // torch::nn::Module::to() has no MemoryFormat overload, so convert
    // the weight tensor itself. NoGradGuard keeps autograd from
    // recording the parameter reassignment.
    torch::NoGradGuard no_grad;
    conv->weight = conv->weight.to(torch::MemoryFormat::ChannelsLast);
  }
  std::cout << conv->weight.strides() << ", " << conv->weight.sizes() << std::endl;

  // With both input and weight in channels-last, cuDNN should select an
  // NHWC convolution kernel (confirmed by the profiler output below).
  auto out = conv->forward(tensor);
  std::cout << out.strides() << ", " << out.sizes() << std::endl;
  return 0;
}
nsys nvprof ./main
...
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
85.7 23104 1 23104.0 23104.0 23104 23104 0.0 void convolve_common_engine_float_NHWC<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int…
14.3 3841 1 3841.0 3841.0 3841 3841 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
@ptrblck Thank you so much