I’m thinking about upgrading some of my custom PyTorch operations to support the {N,H,W,C} format. However, I’m still confused about how channels-last tensors work in PyTorch.
First, what does .contiguous(memory_format=torch.channels_last) do?
import torch

a = torch.randn((1, 3, 128, 128)).contiguous(memory_format=torch.channels_last)
b = torch.randn((1, 128, 128, 3))
c = torch.randn((1, 128, 128, 3)).contiguous(memory_format=torch.channels_last)
d = torch.randn((1, 128, 128, 3)).permute([0, 3, 1, 2]).contiguous(memory_format=torch.channels_last)
Which tensor is truly stored as {N,H,W,C} (e.g., {1, 128, 128, 3}) in memory: a, b, c, or d? By the way, when I passed them to convolution layers, both b and c raised errors.
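One way to reason about the four tensors is to work out their strides by hand. The sketch below uses plain Python (no PyTorch) and assumes the usual convention that a 4-D tensor's logical shape is always read as (N, C, H, W), with channels-last changing only the strides. The helper names (`contiguous_strides`, `channels_last_strides`, `permute`) are made up for this illustration; in PyTorch the same numbers should be visible through `tensor.shape` and `tensor.stride()`.

```python
def contiguous_strides(shape):
    """Strides (in elements) of a row-major tensor with the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_strides(shape):
    """Strides of a tensor with logical shape (N, C, H, W) stored as NHWC."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def permute(shape, strides, dims):
    """Reorder dimensions without moving data, like Tensor.permute."""
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


# a: logical (1, 3, 128, 128), converted to channels-last.
shape_a = (1, 3, 128, 128)
strides_a = channels_last_strides(shape_a)

# b: logical (1, 128, 128, 3), plain row-major; read as N=1, C=128, H=128, W=3.
shape_b = (1, 128, 128, 3)
strides_b = contiguous_strides(shape_b)

# c: same logical shape as b, converted to channels-last (C is still 128).
shape_c = shape_b
strides_c = channels_last_strides(shape_c)

# d: b's data viewed through permute(0, 3, 1, 2); the strides already match
# channels-last, so the later .contiguous(...) call has nothing to rearrange.
shape_d, strides_d = permute(shape_b, strides_b, (0, 3, 1, 2))

for name, shape, strides in [
    ("a", shape_a, strides_a),
    ("b", shape_b, strides_b),
    ("c", shape_c, strides_c),
    ("d", shape_d, strides_d),
]:
    print(name, shape, strides)
```

Under these assumptions, a and d both end up with logical shape (1, 3, 128, 128) and a channel stride of 1, meaning the 3 channel values of each pixel sit next to each other in memory. b and c keep the logical shape (1, 128, 128, 3), which a convolution would read as 128 input channels, and that would be consistent with the errors.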
Second, how can I tell which memory format a tensor is using in C++? In the documentation, the sample code for creating custom operations in C++ is:
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT_TENSOR(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  // Validate every argument with the macro defined above.
  CHECK_INPUT_TENSOR(input);
  CHECK_INPUT_TENSOR(weights);
  CHECK_INPUT_TENSOR(bias);
  CHECK_INPUT_TENSOR(old_h);
  CHECK_INPUT_TENSOR(old_cell);
  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}
I assume the x.is_contiguous() called by CHECK_CONTIGUOUS(x) confirms that x is {N,C,H,W}, but how do I know whether it is {N,H,W,C}? If x.is_contiguous() returns false, is the tensor definitely {N,H,W,C}?
torch::Tensor my_op(const torch::Tensor &input,
                    const std::string &data_format) {
  if (data_format == "NCHW") {
    CHECK_INPUT_TENSOR(input);
  }
  else if (data_format == "NHWC") {
    // How do I make sure that the tensor is NHWC?
  }
  else
    throw std::invalid_argument("Error: Input data_format should be \"NCHW\" or \"NHWC\".");
  torch::Tensor output = torch::zeros_like(input);
  if (data_format == "NCHW") {
    // call cuda function for NCHW input
  }
  else if (data_format == "NHWC") {
    // call cuda function for NHWC input
  }
  return output;
}
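As a rough model of what such a check has to decide, the layout of a 4-D tensor can be classified from its shape and strides alone. The sketch below is plain Python and only approximates PyTorch's rules (it ignores the strides of size-1 dimensions, as PyTorch's contiguity checks do); `classify_layout` is a made-up helper, not a PyTorch function. In C++, the corresponding built-in check is, as far as I know, `x.is_contiguous(at::MemoryFormat::ChannelsLast)` on a 4-D tensor.

```python
def classify_layout(shape, strides):
    """Classify a 4-D tensor with logical shape (N, C, H, W) by its strides.

    Returns "NCHW", "NHWC", "ambiguous" (both layouts match, which can
    happen when some dimensions have size 1), or "neither".
    """
    n, c, h, w = shape
    nchw = (c * h * w, h * w, w, 1)   # default row-major strides
    nhwc = (h * w * c, 1, w * c, c)   # channels-last strides

    def matches(expected):
        # The stride of a size-1 dimension never affects addressing, so skip it.
        return all(
            size == 1 or actual == wanted
            for size, actual, wanted in zip(shape, strides, expected)
        )

    is_nchw, is_nhwc = matches(nchw), matches(nhwc)
    if is_nchw and is_nhwc:
        return "ambiguous"
    if is_nchw:
        return "NCHW"
    if is_nhwc:
        return "NHWC"
    return "neither"


# Default layout for logical shape (1, 3, 128, 128).
print(classify_layout((1, 3, 128, 128), (49152, 16384, 128, 1)))
# Channels-last layout for the same logical shape.
print(classify_layout((1, 3, 128, 128), (49152, 1, 384, 3)))
# Every other column of a (1, 3, 128, 256) tensor: not contiguous in either sense.
print(classify_layout((1, 3, 128, 128), (98304, 32768, 256, 2)))
```

The third case is why a false result from `is_contiguous()` on its own does not seem to imply {N,H,W,C}: sliced or transposed views are non-contiguous too, so the channels-last layout would need to be checked explicitly.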