How to perform the following 1D convolution?

Hey all,
I have a tensor t with shape (b,c,n,m), where b is the batch size, c is the number of channels, n is the sequence length (number of tokens), and m is the number of parallel representations of the data (similar to the different heads in a transformer).
I want to perform a 1D conv over the channels and sequence length, such that each of the m representations has its own convolution layer. Something like this:

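import torch
b, c, n, m = 2, 3, 5, 2  # example sizes
t = torch.rand([b, c, n, m])
convs = torch.nn.ModuleList([torch.nn.Conv1d(c, c, 1) for _ in range(m)])  # one conv per representation
output = torch.empty(b, c, n, m)
for i in range(m):
    output[:, :, :, i] = convs[i](t[:, :, :, i])  # apply the i-th conv to the i-th slice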

I’m pretty sure that for loop isn’t the way to go. How can I perform this computation efficiently?

Thanks!

Hi Omer!

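You can .permute() and .reshape() your input tensor so that the m
representations are folded into the channel dimension, and use the
groups constructor argument of Conv1d. With groups = m, each block
of c channels gets its own independent set of weights: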

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> b = 2
>>> c = 3
>>> n = 5
>>> m = 2
>>> t = torch.rand ([b, c, n, m])
>>> u = t.permute (0, 3, 1, 2).reshape (b, m * c, n)
>>> conv = torch.nn.Conv1d (m * c, m * c, 1, groups = m)
>>> conv (u)
tensor([[[-0.0624, -0.3195, -0.3976, -0.0028, -0.3123],
         [ 0.5862,  0.6727,  0.9059,  0.3821,  0.9464],
         [-0.9408, -1.1464, -0.7950, -0.8718, -0.6577],
         [-0.4530, -0.5568, -0.4779, -0.8847, -0.5256],
         [-0.5623, -0.7672, -0.3805, -0.7634, -0.5598],
         [-0.6248, -0.6711, -0.2747, -0.6640, -0.0652]],

        [[-0.3047, -0.4761, -0.3934, -0.1065, -0.0300],
         [ 0.6465,  0.8431,  0.7757,  0.8131,  0.4114],
         [-1.0906, -0.9153, -0.9905, -0.7797, -0.7686],
         [-0.8211, -0.4375, -1.0047, -0.2747, -0.2525],
         [-1.0808, -0.6798, -0.8439, -0.2883, -0.5769],
         [-0.2513, -0.5921, -0.3546, -0.4071, -0.3656]]],
       grad_fn=<SqueezeBackward1>)
>>> t
tensor([[[[0.1304, 0.5134],
          [0.7426, 0.7159],
          [0.5705, 0.1653],
          [0.0443, 0.9628],
          [0.2943, 0.0992]],

         [[0.8096, 0.0169],
          [0.8222, 0.1242],
          [0.7489, 0.3608],
          [0.5131, 0.2959],
          [0.7834, 0.7405]],

         [[0.8050, 0.3036],
          [0.9942, 0.5025],
          [0.3734, 0.0413],
          [0.8387, 0.0604],
          [0.1773, 0.3301]]],


        [[[0.6857, 0.6960],
          [0.8303, 0.5216],
          [0.7438, 0.8290],
          [0.0219, 0.0813],
          [0.0172, 0.1464]],

         [[0.7492, 0.9450],
          [0.6737, 0.1135],
          [0.7421, 0.7810],
          [0.9446, 0.0451],
          [0.4282, 0.2427]],

         [[0.9363, 0.7784],
          [0.5605, 0.5312],
          [0.7132, 0.1075],
          [0.4496, 0.1255],
          [0.6784, 0.6550]]]])
>>> u
tensor([[[0.1304, 0.7426, 0.5705, 0.0443, 0.2943],
         [0.8096, 0.8222, 0.7489, 0.5131, 0.7834],
         [0.8050, 0.9942, 0.3734, 0.8387, 0.1773],
         [0.5134, 0.7159, 0.1653, 0.9628, 0.0992],
         [0.0169, 0.1242, 0.3608, 0.2959, 0.7405],
         [0.3036, 0.5025, 0.0413, 0.0604, 0.3301]],

        [[0.6857, 0.8303, 0.7438, 0.0219, 0.0172],
         [0.7492, 0.6737, 0.7421, 0.9446, 0.4282],
         [0.9363, 0.5605, 0.7132, 0.4496, 0.6784],
         [0.6960, 0.5216, 0.8290, 0.0813, 0.1464],
         [0.9450, 0.1135, 0.7810, 0.0451, 0.2427],
         [0.7784, 0.5312, 0.1075, 0.1255, 0.6550]]])
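
Note that conv (u) has shape (b, m * c, n), so to get back to the original (b, c, n, m) layout the permute / reshape has to be undone. The following is a minimal sketch of that step, together with a check against the for-loop version. The names convs, looped, grouped, v and out are only illustrative, and copying the per-representation weights into the grouped layer is done here purely so the two results can be compared; it assumes group i of the grouped layer covers channels i * c through (i + 1) * c - 1, which matches the permute order used above.

```python
import torch

torch.manual_seed(2021)
b, c, n, m = 2, 3, 5, 2

t = torch.rand(b, c, n, m)

# Reference: one independent Conv1d per representation, as in the original loop.
convs = [torch.nn.Conv1d(c, c, 1) for _ in range(m)]
looped = torch.stack([convs[i](t[:, :, :, i]) for i in range(m)], dim=-1)

# Single grouped convolution; group i handles channels [i*c, (i+1)*c).
grouped = torch.nn.Conv1d(m * c, m * c, 1, groups=m)
with torch.no_grad():
    # Copy the per-representation parameters into the matching groups
    # (only needed here so that the two outputs are comparable).
    grouped.weight.copy_(torch.cat([conv.weight for conv in convs], dim=0))
    grouped.bias.copy_(torch.cat([conv.bias for conv in convs], dim=0))

# (b, c, n, m) -> (b, m, c, n) -> (b, m*c, n)
u = t.permute(0, 3, 1, 2).reshape(b, m * c, n)
v = grouped(u)

# Undo the layout change: (b, m*c, n) -> (b, m, c, n) -> (b, c, n, m)
out = v.reshape(b, m, c, n).permute(0, 2, 3, 1)

print(out.shape)
print(torch.allclose(out, looped, atol=1e-5))
```

In actual use the weight copying is not needed; the grouped layer is simply trained in place of the m separate layers.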

Best.

K. Frank

Excellent! Thank you very much!