Image to vectors operation -- how to calculate a receptive fields tensor

Hi, I’m interested in experimenting with a variant of the standard 2D convolution, for which it would be very convenient to express the convolution operation as a tensor multiplication between the weights tensor and a tensor that contains all the receptive fields as vectors.

That is, I’m looking for an image-to-columns/vectors (im2col) method that, given the parameters of a 2D convolution (i.e., kernel_size, stride, and padding), “breaks” the input image (or, for that matter, any appropriate input feature map) into vectors, each corresponding to one receptive field defined by those parameters. That way, the convolution operation can be written as a tensor multiplication (or at least I think so).
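For reference, the formulation being asked about can be sketched as follows, assuming dilation 1 and groups=1. The symbols are introduced here only for illustration: N is the batch size, (k_h, k_w) the kernel size, (s_h, s_w) the stride, (p_h, p_w) the padding, X^col the tensor of receptive-field vectors, and Θ the weight tensor flattened to one row per output channel.

$$
X^{\mathrm{col}} \in \mathbb{R}^{N \times (C_{\mathrm{in}} k_h k_w) \times L}, \qquad
\Theta \in \mathbb{R}^{C_{\mathrm{out}} \times (C_{\mathrm{in}} k_h k_w)}, \qquad
L = H_{\mathrm{out}} W_{\mathrm{out}}
$$

$$
Y_{n,o,l} = \sum_{j=1}^{C_{\mathrm{in}} k_h k_w} \Theta_{o,j}\, X^{\mathrm{col}}_{n,j,l} + b_o
$$

$$
H_{\mathrm{out}} = \left\lfloor \frac{H + 2p_h - k_h}{s_h} \right\rfloor + 1, \qquad
W_{\mathrm{out}} = \left\lfloor \frac{W + 2p_w - k_w}{s_w} \right\rfloor + 1
$$

Reshaping the last dimension of Y (of length L) into (H_out, W_out) then gives the usual (N, C_out, H_out, W_out) convolution output.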

You could use nn.Unfold to apply the im2col operation; the docs for this module also give an example of how a custom convolution can be implemented with it.

Hi @ptrblck, many thanks for your response. Based on your suggestion and the discussion found here, I figured it out. For completeness, here is a simple example:

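import torch
import torch.nn as nn

# Input tensor of shape (bs, C_in, H, W), in double precision
bs = 128
C_in = 3
H = 32
W = 32

X = torch.randn(bs, C_in, H, W, dtype=torch.float64)

# 2D convolution (kernel_size, stride, and padding here are example values);
# .double() so the weights match the dtype of X
C_out = 16
conv = nn.Conv2d(in_channels=C_in, out_channels=C_out,
                 kernel_size=3, stride=1, padding=1).double()

# Reference output of the standard 2D convolution
Y = conv(X)

# ---

# Define the nn.Unfold (im2col) operation with the same parameters as conv
unfold = nn.Unfold(kernel_size=conv.kernel_size, dilation=conv.dilation,
                   padding=conv.padding, stride=conv.stride)

# Define the nn.Fold operation that maps the L columns back onto an (h_out, w_out) grid
h_in, w_in = X.shape[2], X.shape[3]
h_out = (h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1
w_out = (w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) // conv.stride[1] + 1
fold = nn.Fold(output_size=(h_out, w_out), kernel_size=1)

# Named `weight` rather than W so the input width W is not shadowed
weight = conv.weight
# unfold(X): (bs, C_in*kh*kw, L) -> transpose -> (bs, L, C_in*kh*kw)
# matmul with flattened weights (C_in*kh*kw, C_out) -> (bs, L, C_out), plus bias
# transpose back -> (bs, C_out, L)
Y_ = (unfold(X).transpose(1, 2).matmul(weight.view(weight.size(0), -1).t()) + conv.bias).transpose(1, 2)
Y__ = fold(Y_)  # (bs, C_out, h_out, w_out)
print("torch.allclose(Y, Y__): {}".format(torch.allclose(Y, Y__)))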
