More efficient way to compute a convolution on a transposed (channels-last) input

TL;DR: is there a better way to apply a Conv1d (or any other N-dim convolution) to an (N, Lin, Cin)-shaped input, producing an (N, Lout, Cout)-shaped output, than doing a pre- and post-transpose on the input/output tensors?

Full question:
I have an input tensor organized as (batch, windows, features), or (N, L, C), that I need to pass through a set of convolution layers. Between the convolutions, however, I need to maintain the axis order (batch, windows, features), because after each layer there is a reshaping step which multiplies the number of windows while reducing the number of features.
Think of it like an upscaling algorithm that takes L elements, each with C channels, does a convolution to get L elements with 2C channels (padding='same'), and then splits each element into two, resulting in 2L elements with C channels each.

Currently I deal with this by doing a transpose before and after each convolution, but I am positive this can be done more efficiently, without making an unnecessary copy during the reshape:

    # Size: (N, L, 100)
    Transpose(-1, -2),  # put the tensor in (batch, features, windows) order (N, 100, L)
    Conv1d(
        in_channels=100,
        out_channels=200,
        kernel_size=3,
        stride=1, padding='same', padding_mode='replicate',
    ),  # Size: (N, 200, L)
    activator(),
    Transpose(-2, -1),  # put the tensor in (batch, windows, features) order (N, L, 200)
    LazyReshape((..., 'l', 200), (..., 'l*2', 100)),  # essentially reshape from (N, L, 200) into (N, 2*L, 100)
    Transpose(-1, -2),  # put the tensor in (batch, features, windows) order (N, 100, 2*L)
    Conv1d(
        in_channels=100,
        out_channels=200,
        kernel_size=3,
        stride=1, padding='same', padding_mode='replicate',
    ),  # Size: (N, 200, 2*L)
    ...

In the example, Transpose(dim1, dim2) is essentially def forward(self, x): return x.transpose(dim1, dim2), while LazyReshape is a bit more complex, but generally does an x.reshape(...) with an intelligently derived shape.
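
For reference, a minimal version of the Transpose helper looks something like this (a sketch; LazyReshape is omitted since its shape derivation is more involved):

import torch.nn as nn

class Transpose(nn.Module):
    # Thin module wrapper around Tensor.transpose, so it can sit inside an nn.Sequential.
    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        return x.transpose(self.dim1, self.dim2)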

My biggest concern (besides the fact that this introduces unnecessary clutter into the network description) is that each “reshape” layer must use reshape and not just view, because the result of transpose is not contiguous. Reshaping in the (batch, features, windows) form instead, i.e. (N, 2C, L) into (N, C, 2L), would not yield the desired result: in the upscaling example, the element computed from windows 0 and 1 would land at index L, not 1. None of this would be needed if the conv layer simply wrote its output in the proper order (I’m pretty sure it doesn’t care whether the input is transposed or not, although don’t quote me on that; it definitely can work in either case).
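
Here is a toy example (made-up sizes N=1, L=3, C=2) illustrating the wrong interleaving you get from reshaping directly in the (N, 2C, L) layout:

import torch

N, L, C = 1, 3, 2
y = torch.arange(N * 2 * C * L).reshape(N, 2 * C, L)  # stand-in for a conv output: (N, 2C, L), contiguous

# Desired upscaling: go to (N, L, 2C) and split each window's 2C features into two
# windows of C features -> (N, 2L, C); reshape has to copy since the transposed view is not contiguous.
want = y.transpose(-1, -2).reshape(N, 2 * L, C)

# Naive reshape directly in the (N, 2C, L) layout, then transpose: same shape, wrong order.
wrong = y.reshape(N, C, 2 * L).transpose(-1, -2)

print(want[0])   # [[0, 3], [6, 9], [1, 4], ...]: window 0's 2C features split into windows 0 and 1
print(wrong[0])  # [[0, 6], [1, 7], [2, 8], ...]: channel-major interleaving; e.g. 9 (from window 0) ends up at index L=3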

If there is a performance benefit to keeping the output in the current (N, C, L) format that outweighs the penalty of copy-reshaping, then I’m cool with it, but I’d like to know about it and plan accordingly :slight_smile:

FWIW, ConvTranspose1d is not what I’m looking for; IIUC it’s essentially the same as Conv1d, just using a transposed convolution matrix, and it still uses the (N, C, L) shape for I/O.
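
A quick check with toy sizes confirming that ConvTranspose1d keeps the (N, C, L) convention on both ends:

import torch

tconv = torch.nn.ConvTranspose1d(in_channels=100, out_channels=50, kernel_size=2, stride=2)
y = tconv(torch.rand(4, 100, 16))  # input is (N, C, L) = (4, 100, 16)
print(y.shape)                     # torch.Size([4, 50, 32]): output is (N, C, L) as well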

P.S. I know that transpose(-1, -2) and transpose(-2, -1) are the same thing, I just use both to keep track which is which.

You could try to use channels-last convolutions by calling conv.to(memory_format=torch.channels_last). Your input data is expected to be in the default channels-first format and you would have to call the to() operation on it, too.
However, since your actual memory layout is already in channels-last, you might be able to use tensor.as_strided_ specifying the shape and stride of your data, which should avoid the explicit transposes.
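
For the 2d case (4d tensors), where channels_last is supported, the idea would look roughly like this (a sketch with made-up sizes):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(memory_format=torch.channels_last)
x = torch.rand(16, 3, 32, 32).to(memory_format=torch.channels_last)  # shape stays NCHW, strides become NHWC
y = conv(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # expected True: the conv should propagate the layout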

That was a good bet, but unfortunately, IIUC, channels_last is reserved for 4d tensors, i.e. batched 2d inputs rather than batched 1d inputs (3d tensors), cf. Tensor Attributes — PyTorch 2.0 documentation.

Tensor.transpose is already good enough to get a strided view, but the channels_last format appears to be ignored by Conv1d:

import torch

N = 32
Lin = 10
Cin = 5
Lout = Lin
Cout = 10

x = torch.rand(N, Lin, Cin)
conv = torch.nn.Conv1d(in_channels=Cin, out_channels=Cout, kernel_size=3, padding='same')

print('0', x.shape, x.stride())
x = x.transpose(-1,-2)
print('1', x.shape, x.stride())
x = conv(x)
print('2', x.shape, x.stride())
x = x.transpose(-2,-1)
print('3', x.shape, x.stride())

returns

0 torch.Size([32, 10, 5]) (50, 5, 1)
1 torch.Size([32, 5, 10]) (50, 1, 5)
2 torch.Size([32, 10, 10]) (100, 10, 1)
3 torch.Size([32, 10, 10]) (100, 1, 10)

while after adding

conv = torch.nn.Conv1d(in_channels=Cin, out_channels=Cout, kernel_size=3, padding='same')
conv = conv.to(memory_format=torch.channels_last)

the output remains the same:

0 torch.Size([32, 10, 5]) (50, 5, 1)
1 torch.Size([32, 5, 10]) (50, 1, 5)
2 torch.Size([32, 10, 10]) (100, 10, 1)
3 torch.Size([32, 10, 10]) (100, 1, 10)

An attempt to use x = x.to(memory_format=torch.channels_last) in turn results in:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [65], in <cell line: 14>()
     12 x = x.transpose(-1,-2)
     13 print('1', x.shape, x.stride())
---> 14 x = x.to(memory_format=torch.channels_last)
     15 print('2', x.shape, x.stride())
     16 x = conv(x)

RuntimeError: required rank 4 tensor to use channels_last format

which appears to be in line with the memory format docs I linked above.
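
To double-check that the rank is really the blocker (a quick experiment of mine), the same call is accepted once a dummy trailing dimension makes the tensor 4d:

import torch

x3d = torch.rand(32, 5, 10)
x4d = x3d.unsqueeze(-1).to(memory_format=torch.channels_last)  # (32, 5, 10, 1) is rank 4, so the call goes through
print(x4d.shape, x4d.stride())  # e.g. torch.Size([32, 5, 10, 1]) (50, 1, 5, 5)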

For a while I thought that changing the order of the activator and transpose layers would help, i.e. doing the following:

relu = torch.nn.ReLU()  # the activator in this test

print('1', x.shape, x.stride())
x = conv(x)
print('2', x.shape, x.stride())
x = x.transpose(-2,-1)
print('3', x.shape, x.stride())
x = relu(x)
print('4', x.shape, x.stride())

but that doesn’t appear to be the case; the activator seems to preserve the memory layout of its input:

1 torch.Size([32, 5, 10]) (50, 1, 5)
2 torch.Size([32, 10, 10]) (100, 10, 1)
3 torch.Size([32, 10, 10]) (100, 1, 10)
4 torch.Size([32, 10, 10]) (100, 1, 10)