I need to interleave conv layers and transformer FFNs (two linear layers with an activation function in between), just like in a ConvNeXt block.
Assume N=num_images, C=num_channels, H=image_height, W=image_width. The conv layers need input & output of shape [N, C, H, W] to use the nn.Conv2d API, while the FFN layers need input & output of shape [N, H, W, C] to apply the linear transformation along the channel dimension. This leads to repeated permutations of the tensor (especially when stacking multiple such blocks), which I assume is slow.
The code would look like:
import torch
import torch.nn.functional as F

# conv_kernel (a [C_out, C_in, kH, kW] weight) and ffn are defined elsewhere
x = torch.randn(N, C, H, W)
x = F.conv2d(x, conv_kernel, padding='same')  # [N, C, H, W]; 'same' padding keeps H, W
x = x.permute(0, 2, 3, 1)  # [N, H, W, C]
x = ffn(x)  # [N, H, W, C]
x = x.permute(0, 3, 1, 2)  # [N, C, H, W]
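For context, the stacked version I have in mind looks roughly like this (just a sketch; the kernel size, hidden width, and class name are placeholders of mine, not taken from ConvNeXt):

import torch
import torch.nn as nn

class ConvFFNBlock(nn.Module):
    # One conv followed by a channel-wise FFN, with the two permutes in question.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):             # x: [N, C, H, W]
        x = self.conv(x)              # [N, C, H, W]
        x = x.permute(0, 2, 3, 1)     # [N, H, W, C]
        x = self.ffn(x)               # [N, H, W, C]
        return x.permute(0, 3, 1, 2)  # [N, C, H, W]

blocks = nn.Sequential(*[ConvFFNBlock(64, 256) for _ in range(4)])  # 8 permutes per pass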
I'll repeat such a block several times, so tensor x gets permuted over and over. Permutation doesn't move data, it just computes a new stride for each dimension, but I'm not sure whether the FFN layer is still efficient when axis C is non-contiguous in memory?
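To make the concern concrete, here is a minimal check (sizes are arbitrary) showing that permute shares storage but leaves axis C with a non-unit stride, so the FFN's matmul reads channel values that sit H*W elements apart in memory:

import torch

N, C, H, W = 2, 64, 32, 32
x = torch.randn(N, C, H, W)
y = x.permute(0, 2, 3, 1)            # [N, H, W, C], no data is copied

print(x.data_ptr() == y.data_ptr())  # True: same underlying storage
print(y.stride())                    # (65536, 32, 1, 1024): axis C has stride H*W
print(y.is_contiguous())             # False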
If not, another option is to use the channels_last memory format to make dimension C contiguous before & after conv2d. But in that case, I don't know how to convert a channels_last [N, C, H, W] tensor into an ordinary contiguous [N, H, W, C] tensor so that the FFN works. I tried to view the channels_last [N, C, H, W] tensor directly, but it doesn't work.
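Here is a minimal repro of that failed attempt (shapes are arbitrary); view raises a RuntimeError because a channels_last tensor is not contiguous in its logical [N, C, H, W] order:

import torch

N, C, H, W = 2, 64, 32, 32
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last)
print(x.shape)  # still [N, C, H, W]; only the memory layout changed

try:
    y = x.view(N, H, W, C)
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."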