Manual Implementation of Unrolled 3D Convolutions

Sure! Please see the code below. The 2D convolution block appears to work well. I have since implemented the 3D convolution block using a similar approach; however, it only works when padding=(0, 0, 0). I'm using torch.Tensor.unfold to unfold the 5D input tensor, since torch.nn.functional.unfold only accepts 4D input. Unfortunately, Tensor.unfold does not have a padding argument.
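
For reference, here is a minimal, self-contained illustration of what torch.Tensor.unfold(dimension, size, step) does along one dimension (the tensor here is just an example, not taken from the code below):

import torch

# unfold(dim, size, step) slides a window of length `size` along `dim` with
# stride `step` and moves the window elements into a new trailing dimension.
x = torch.arange(8.0)        # shape: (8,)
windows = x.unfold(0, 3, 2)  # shape: (3, 3) -- three windows of length 3
print(windows)
# tensor([[0., 1., 2.],
#         [2., 3., 4.],
#         [4., 5., 6.]])

Chaining three such unfolds over the spatial dimensions of a 5D tensor is what produces the patches in the 3D block below.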

Do you have any idea how I could extend the 3D convolution block below to support arbitrary padding lengths?

UPDATE: I have managed to use torch.nn.functional.pad to manually zero-pad the input before unfolding it, and this works for arbitrary padding lengths. I have updated the code below.

I believe this has now been resolved :slight_smile:

import torch
from torch import nn
import torch.nn.functional as F


# 2D convolution: compare nn.Conv2d against an unfold + matmul implementation
input_dim = [4, 4]
in_channels = 2
out_channels = 4
kernel_size = (3, 3)
padding = (1, 1)
stride = (1, 1)
input_tensor = torch.zeros(1, in_channels, input_dim[0], input_dim[1]).uniform_(-1, 1)
conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
output_tensor = conv(input_tensor)
# Output spatial size per dimension: floor((input + 2 * padding - kernel) / stride) + 1
output_dim = [0, 0]
output_dim[0] = (input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) // stride[0] + 1
output_dim[1] = (input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) // stride[1] + 1
# F.unfold applies the zero-padding itself and returns (1, C * kH * kW, L),
# one flattened patch per column.
unfolded_input_tensor = F.unfold(input_tensor, kernel_size=kernel_size, padding=padding, stride=stride)
kernels_flat = conv.weight.detach().clone().view(out_channels, -1)
alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor).view(1, out_channels, output_dim[0], output_dim[1])
print(torch.allclose(output_tensor, alt_output_tensor))
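
# As a shape sanity check for the 2D case above: F.unfold flattens each
# (C, kH, kW) patch into one column, so the matmul is simply
# (out_channels, C*kH*kW) x (C*kH*kW, L), where L is the number of output
# positions. A tiny illustration with made-up sizes (my addition, not part
# of the original comparison):
demo = torch.randn(1, 2, 4, 4)
demo_cols = F.unfold(demo, kernel_size=(3, 3), padding=(1, 1))
print(demo_cols.shape)  # torch.Size([1, 18, 16]) -- 2 * 3 * 3 = 18, L = 4 * 4 = 16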

# 3D convolution: same idea, but torch.Tensor.unfold is used since F.unfold
# only supports 4D input, so the zero-padding has to be applied manually.
input_dim = [22, 59, 114]
in_channels = 1
out_channels = 16
kernel_size = (22, 5, 5)
padding = (1, 2, 2)
stride = (1, 2, 2)
input_tensor = torch.zeros(1, in_channels, input_dim[0], input_dim[1], input_dim[2]).uniform_(-1, 1)
conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
output_tensor = conv(input_tensor)
# Output size per dimension, computed from the *unpadded* input
output_dim = [0, 0, 0]
output_dim[0] = (input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) // stride[0] + 1
output_dim[1] = (input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) // stride[1] + 1
output_dim[2] = (input_tensor.shape[4] - kernel_size[2] + 2 * padding[2]) // stride[2] + 1
# F.pad takes the padding for the *last* dimension first: (W, W, H, H, D, D)
if not all(item == 0 for item in padding):
    input_tensor = F.pad(input_tensor, pad=(padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]), mode="constant", value=0)

# Each unfold slides a window along one dimension and appends it as a new
# trailing dimension, giving (1, C, D_out, H_out, W_out, kD, kH, kW).
unfolded_input_tensor = input_tensor.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
# Note: this reshape only lines up with the flattened weights because
# in_channels == 1; with more channels, the channel dimension would have to
# be permuted next to the kernel dimensions first.
unfolded_input_tensor = unfolded_input_tensor.reshape(-1, kernel_size[0] * kernel_size[1] * kernel_size[2])
kernels_flat = conv.weight.detach().clone().view(out_channels, -1)
alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor.T).view(1, out_channels, output_dim[0], output_dim[1], output_dim[2])
print(torch.allclose(output_tensor, alt_output_tensor))
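
One caveat I noticed while cleaning this up: the reshape in the 3D block only matches the flattened weights because in_channels == 1. In case anyone needs multiple input channels, here is a sketch of how I believe the patch layout would need to change; the sizes are made up, and the key step is the permute that moves the channel dimension next to the kernel dimensions so each flattened patch matches conv.weight.view(out_channels, -1):

import torch
from torch import nn
import torch.nn.functional as F

in_channels, out_channels = 3, 8  # example values only
kernel_size = (2, 3, 3)
padding = (1, 1, 1)
stride = (1, 2, 2)

x = torch.randn(1, in_channels, 8, 16, 16)
conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, stride=stride, bias=False)
reference = conv(x)

x_padded = F.pad(x, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]))
patches = (
    x_padded.unfold(2, kernel_size[0], stride[0])
            .unfold(3, kernel_size[1], stride[1])
            .unfold(4, kernel_size[2], stride[2])
)  # (N, C, D_out, H_out, W_out, kD, kH, kW)
n, c, d, h, w = patches.shape[:5]
# Move C next to the kernel dims so each row flattens to (C * kD * kH * kW),
# matching conv.weight.view(out_channels, -1).
patches = patches.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(n * d * h * w, -1)
weights = conv.weight.detach().view(out_channels, -1)
output = (patches @ weights.T).view(n, d, h, w, out_channels).permute(0, 4, 1, 2, 3)
print(torch.allclose(reference, output, atol=1e-6))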