Torch conv1d for two 2D matrices

Dear All,

I'm working on a simulation algorithm where the linear algebra is handled by PyTorch. One step in the algorithm is a 1d convolution of two vectors. This step runs many times, so it needs to be fast, and I decided to speed things up further by allowing batched input. That means I sometimes need to convolve two matrices row-wise, i.e. along the second dimension. I can get the 1d convolution to work with torch.conv1d, but I cannot figure out how to do it for the matrix case. Below is a small example that does the same thing with a double for-loop; it gives the expected result, but it is not vectorized (so it will slow things down) and it is not very elegant.

My question: how do I get the row-wise 1d convolution of two matrices to work with torch.conv1d?

Example

import torch

na = 2
nv = 3
nbatch = 4

a1d = torch.randn(na)
v1d = torch.randn(nv)

def convolve(a, v):
    if a.ndim == 1:
        # 1D case: conv1d computes cross-correlation, so flip v and pad
        # with len(v) - 1 on both sides to get the full convolution
        padding = v.shape[-1] - 1
        b = torch.conv1d(
            input=a.view(1, 1, -1), weight=v.flip(0).view(1, 1, -1), padding=padding, stride=1
        ).squeeze()
        return b
    elif a.ndim == 2:
        # this is the 2D case that is ugly!
        nrows, vcols = v.shape
        acols = a.shape[1]
        # outer product of each row of a with the corresponding row of v
        expanded = a.view((nrows, acols, 1)) * v.view((nrows, 1, vcols))
        # the full convolution has length acols + vcols - 1
        noutdim = acols + vcols - 1

        b = torch.zeros((nrows, noutdim))
        for i in range(acols):
            for j in range(vcols):
                b[:, i + j] += expanded[:, i, j]
        return b
    else:
        raise NotImplementedError
        

a2d = torch.cat([a1d[None,:], torch.randn((nbatch-1, na))])
v2d = torch.cat([v1d[None,:], torch.randn((nbatch-1, nv))])

b1d = convolve(a1d, v1d)
b2d = convolve(a2d, v2d) 

print(b1d) 
tensor([ 0.6887,  0.9372, -1.6958, -0.0101])

print(b2d) # notice that the first row matches that of b1d as expected
tensor([[ 0.6887,  0.9372, -1.6958, -0.0101],
        [-0.0328,  0.9093, -0.6063,  0.4537],
        [-0.2817, -0.9321,  1.0376,  1.4543],
        [-2.8016, -1.6350,  1.2036,  0.3089]])

Hi Tomek!

Package your nbatch dimension as “channels” (in_channels) and then use conv1d()'s groups feature, so that each weight vector in your batch of weight vectors is applied separately to the corresponding “channel” of your input.
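
Condensed, the whole trick is shape bookkeeping. Here is a minimal, self-contained sketch with illustrative sizes (not your exact code):

import torch

nbatch, na, nv = 4, 2, 3        # illustrative sizes
a = torch.randn(nbatch, na)     # batch of signals
v = torch.randn(nbatch, nv)     # batch of kernels

# input:  [1, nbatch, na]  -- one "sample" whose channels are the batch rows
# weight: [nbatch, 1, nv]  -- out_channels = nbatch, in_channels / groups = 1
# flip the kernels because conv1d actually computes cross-correlation
out = torch.conv1d(
    a.unsqueeze(0),
    v.flip(1).unsqueeze(1),
    padding=nv - 1,
    groups=nbatch,
).squeeze(0)                    # [nbatch, na + nv - 1]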

Here is your code with the “groups” version added on at the end:

import torch
print (torch.__version__)

torch.manual_seed (2021)

na = 2
nv = 3
nbatch = 4

a1d = torch.randn(na)
v1d = torch.randn(nv)

def convolve(a, v):
    if a.ndim == 1:
        # 1D case: conv1d computes cross-correlation, so flip v and pad
        # with len(v) - 1 on both sides to get the full convolution
        padding = v.shape[-1] - 1
        b = torch.conv1d(
            input=a.view(1, 1, -1), weight=v.flip(0).view(1, 1, -1), padding=padding, stride=1
        ).squeeze()
        return b
    elif a.ndim == 2:
        # this is the 2D case that is ugly!
        nrows, vcols = v.shape
        acols = a.shape[1]
        # outer product of each row of a with the corresponding row of v
        expanded = a.view((nrows, acols, 1)) * v.view((nrows, 1, vcols))
        # the full convolution has length acols + vcols - 1
        noutdim = acols + vcols - 1

        b = torch.zeros((nrows, noutdim))
        for i in range(acols):
            for j in range(vcols):
                b[:, i + j] += expanded[:, i, j]
        return b
    else:
        raise NotImplementedError
        

a2d = torch.cat([a1d[None,:], torch.randn((nbatch-1, na))])
v2d = torch.cat([v1d[None,:], torch.randn((nbatch-1, nv))])

b1d = convolve(a1d, v1d)
b2d = convolve(a2d, v2d) 

print(b1d)
# tensor([ 0.6887,  0.9372, -1.6958, -0.0101])

print(b2d) # notice that the first row matches that of b1d as expected
# tensor([[ 0.6887,  0.9372, -1.6958, -0.0101],
#         [-0.0328,  0.9093, -0.6063,  0.4537],
#         [-0.2817, -0.9321,  1.0376,  1.4543],
#         [-2.8016, -1.6350,  1.2036,  0.3089]])

a2d_batch_channel = a2d.unsqueeze (0)
print (a2d_batch_channel.shape)   # [batch_size, nChannel, length] = [1, nbatch, na]

v2d_weight = v2d.flip (1).unsqueeze (1)   # flip because conv1d is cross-correlation
print (v2d_weight.shape)   # [out_channels, in_channels / groups, kernel_size] = [nbatch, 1, nv]

padding = v2d.shape[-1] - 1   # pad for the "full" convolution

b2db = torch.conv1d (a2d_batch_channel, v2d_weight, padding = padding, groups = nbatch)

print(b2db)  # matches b2d, with an extra leading batch dimension of size 1

And here is its output:

1.9.0
tensor([-1.9703, -1.3872, -1.8193, -0.4445])
tensor([[-1.9703, -1.3872, -1.8193, -0.4445],
        [ 0.8280, -1.0800, -4.3766,  0.3490],
        [ 1.4539, -1.4477,  2.5715, -0.9555],
        [-0.9232,  4.5329, -0.3340, -1.1177]])
torch.Size([1, 4, 2])
torch.Size([4, 1, 3])
tensor([[[-1.9703, -1.3872, -1.8193, -0.4445],
         [ 0.8280, -1.0800, -4.3766,  0.3490],
         [ 1.4539, -1.4477,  2.5715, -0.9555],
         [-0.9232,  4.5329, -0.3340, -1.1177]]])
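
If you want to fold this back into your convolve() function, something along these lines should work for the 2D branch (just a sketch, only checked against this example; the helper name is mine, and squeeze(0) removes the dummy batch dimension):

def convolve2d_groups(a, v):
    # a: [nbatch, na], v: [nbatch, nv]  ->  [nbatch, na + nv - 1]
    nbatch, nv = v.shape
    out = torch.conv1d(
        a.unsqueeze(0),             # [1, nbatch, na]
        v.flip(1).unsqueeze(1),     # [nbatch, 1, nv]
        padding=nv - 1,
        groups=nbatch,
    )
    return out.squeeze(0)

print(torch.allclose(convolve2d_groups(a2d, v2d), b2d))   # should print True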

Best.

K. Frank

Thank you very much!