Swap axes in PyTorch?

Hi, in TensorFlow, tf.nn.conv2d has a data_format option that specifies the data layout as NHWC or NCHW.

Is there an equivalent operation in PyTorch?

If not, should we convert the Variable to a numpy.array, use np.swapaxes, and convert it back into a Variable?
And under such circumstances, will the gradient be tracked properly?


No, we only support the NCHW format. You can use .permute to swap the axes.
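To make this concrete, here is a minimal sketch (the tensor shapes are just illustrative) showing that .permute converts NHWC to NCHW and that gradients flow through it, so no numpy round-trip is needed:

```python
import torch

# Hypothetical NHWC tensor: (batch, height, width, channels)
x = torch.randn(8, 32, 32, 3, requires_grad=True)

# .permute reorders all axes at once; here NHWC -> NCHW
y = x.permute(0, 3, 1, 2)
print(y.shape)  # torch.Size([8, 3, 32, 32])

# permute is recorded by autograd, so the gradient is tracked properly
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 32, 32, 3])
```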


Thanks, I have just checked the docs, but it seems that I simply missed it…

@Veril transpose only applies to two axes, while permute can be applied to all the axes at the same time.
For example:

import torch

a = torch.rand(1, 2, 3, 4)

# Chained pairwise swaps: (1,2,3,4) -> (4,2,3,1) -> (4,3,2,1)
print(a.transpose(0, 3).transpose(1, 2).size())  # torch.Size([4, 3, 2, 1])

# One permute reverses all axes at once
print(a.permute(3, 2, 1, 0).size())  # torch.Size([4, 3, 2, 1])

BTW, permute internally calls transpose a number of times.


Indeed, as a shortcut you can use the in-place variant

tensor.transpose_(0, 1)

instead of

tensor = tensor.transpose(0, 1)

But note that the performance difference is not significant: transpose neither copies nor allocates new memory, it only swaps the strides.
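You can check this yourself, a small sketch: the transposed tensor is a view over the same storage, with the strides swapped.

```python
import torch

t = torch.arange(6).reshape(2, 3)
u = t.transpose(0, 1)

# Same underlying storage: no memory was copied or allocated
print(t.data_ptr() == u.data_ptr())  # True

# Only the strides differ
print(t.stride(), u.stride())  # (3, 1) (1, 3)
```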


Awesome method! Why not combine permute and transpose, or make transpose inaccessible to users, since it's used internally by permute, as fmassa mentioned?
