Is my understanding of kernel stride correct in this example?

Hello,

I have a matrix:

A1 A2 A3 A4 A5 ... 
B1 B2 B3 B4 B5 ... 
C1 C2 C3 C4 C5 ... 
D1 D2 D3 D4 D5 ... 
.. .. .. .. ..

I would like to apply a 2D conv kernel of size (2, 3) to this matrix. The kernel should jump over the rows in steps of 2, i.e. in the first stride step it should cover rows A and B, and in the second stride step rows C and D, as shown below (the kernel is marked with [.], and the matrix after >> shows how the kernel moves from left to right; not strictly necessary, but I had fun typing these matrices :smiley: ;)):

First “Stride Step”:

[A1 A2 A3] A4 A5 ...  >>  A1 [A2 A3 A4] A5 ...    
[B1 B2 B3] B4 B5 ...  >>  B1 [B2 B3 B4] B5 ...    
 C1 C2 C3  C4 C5 ...  >>  C1  C2 C3 C4  C5 ...    
 D1 D2 D3  D4 D5 ...  >>  D1  D2 D3 D4  D5 ...  

Second “Stride Step”:

 A1 A2 A3  A4 A5 ...  >>  A1  A2 A3 A4  A5 ... 
 B1 B2 B3  B4 B5 ...  >>  B1  B2 B3 B4  B5 ... 
[C1 C2 C3] C4 C5 ...  >>  C1 [C2 C3 C4] C5 ... 
[D1 D2 D3] D4 D5 ...  >>  D1 [D2 D3 D4] D5 ... 
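One way to check the stride steps sketched above is `torch.nn.functional.unfold`, which extracts every window a kernel of the given size and stride would see. A minimal sketch, assuming a 4 x 6 matrix whose values encode their position (A1 = 11, B3 = 23, D6 = 46; the encoding is only for readability):

```python
import torch
import torch.nn.functional as F

# Encode positions as numbers: row A -> 10s, row B -> 20s, ...; column k -> +k
# e.g. A1 = 11, B3 = 23, D6 = 46
H, W = 4, 6
rows = torch.arange(1, H + 1).view(H, 1) * 10
cols = torch.arange(1, W + 1).view(1, W)
x = (rows + cols).float().view(1, 1, H, W)

# Extract every window the (2, 3) kernel sees with stride (2, 1) and no padding
patches = F.unfold(x, kernel_size=(2, 3), stride=(2, 1))  # shape (1, 2*3, L)
windows = patches[0].T.reshape(-1, 2, 3).long()           # one 2x3 window per row

for w in windows:
    print(w.tolist())
```

The first window printed should be A1-A3 over B1-B3, and the fifth (the first one of the second stride step) C1-C3 over D1-D3; no window pairs a 20s row with a 30s row, i.e. B with C.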

In my application, it is crucial that the kernel only combines the row pairs AB and CD (and never, for instance, BC). The setting I chose was simply stride = (2, 1) with ‘valid’ padding (i.e. padding = 0). Just to make sure, are these the correct settings?
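Along the height axis, a convolution without dilation produces floor((H + 2 * padding - kernel_h) / stride_h) + 1 output rows, and output row i starts at input row i * stride_h. A small pure-Python sketch of that rule (the helper `row_windows` is hypothetical, and it assumes padding = 0 and no dilation) lists which input rows each output row covers:

```python
def row_windows(height, kernel_h=2, stride_h=2):
    """Return (first, last) input-row indices covered by each output row.

    Assumes padding = 0 ('valid') and no dilation; rows are 0-indexed,
    so 0 = A, 1 = B, 2 = C, 3 = D, ...
    """
    out_h = (height - kernel_h) // stride_h + 1
    return [(i * stride_h, i * stride_h + kernel_h - 1) for i in range(out_h)]


print(row_windows(4))              # AB, CD
print(row_windows(5))              # AB, CD; the fifth row is left out
print(row_windows(4, stride_h=1))  # AB, BC, CD: stride 1 would mix B and C
```

With kernel height equal to stride height, the row windows tile the matrix without overlap. If the number of rows is odd, the last row is silently dropped under 'valid' padding.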

Thanks!

Best, JZ

Hi, yes this is correct!

If you use the following code, you can define your own kernel and see what is happening when you do the convolution.

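In my case I put a 1 in the lower right corner of the kernel, but you can change it to whatever you want. With that kernel, each output value is simply the bottom-right element of its window, so the printed output shows that you get your desired result.

Along the x axis (width) the kernel advances by 1 column, whereas along the y axis (height) it advances by 2 rows, so it only ever sees non-overlapping row pairs (AB, CD, ...).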
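To confirm this for an arbitrary kernel, not just the one-hot one, a sketch like the following compares `nn.Conv2d` with a hand-written 'valid' cross-correlation in which output row i only reads input rows 2i and 2i + 1. The 4 x 6 input size and the random seed are assumptions for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 6)  # rows A..D, 6 columns
conv = nn.Conv2d(1, 1, kernel_size=(2, 3), stride=(2, 1), padding=0, bias=False)
w = conv.weight[0, 0]  # the (randomly initialised) 2x3 kernel

with torch.no_grad():
    out = conv(x)  # shape (1, 1, 2, 4)
    # Manual 'valid' cross-correlation (no kernel flip, as in nn.Conv2d):
    # output row i only reads input rows 2*i and 2*i + 1
    manual = torch.empty(out.shape[2], out.shape[3])
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            manual[i, j] = (x[0, 0, 2 * i:2 * i + 2, j:j + 3] * w).sum()

print(out.shape)
print(torch.allclose(out[0, 0], manual, atol=1e-6))
```

If the two agree, every output value depends only on one of the row pairs AB or CD.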

Hope this helps!

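```python
import torch
import torch.nn as nn

your_matrix = torch.randn(1, 1, 4, 6)  # (batch, channels, height, width)

# 2x3 kernel with a single 1 in the lower right corner
kernel = torch.tensor([[0., 0., 0.],
                       [0., 0., 1.]]).unsqueeze(0).unsqueeze(0)

# kernel_size=(2, 3), stride=(2, 1), padding=0 ('valid')
conv = nn.Conv2d(1, 1, kernel_size=(2, 3), stride=(2, 1), bias=False)
with torch.no_grad():
    conv.weight.copy_(kernel)  # overwrite the random init with our kernel

print(your_matrix)
print(conv(your_matrix))
```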
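Because the kernel above has a single 1 in its lower-right corner, the output should equal rows B and D (every second row starting from index 1) and columns 3 to 6 of the input. A minimal sketch of that check, assuming the same 4 x 6 input shape:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 4, 6)
kernel = torch.tensor([[0., 0., 0.],
                       [0., 0., 1.]]).view(1, 1, 2, 3)

conv = nn.Conv2d(1, 1, kernel_size=(2, 3), stride=(2, 1), bias=False)
with torch.no_grad():
    conv.weight.copy_(kernel)
    out = conv(x)

# Each output is the bottom-right element of its window:
# rows B and D (1::2), columns 3..6 (2:)
expected = x[:, :, 1::2, 2:]
print(out.shape)
print(torch.allclose(out, expected, atol=1e-6))
```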



Very handy trick, thanks for that! Best, JZ
