Manual Implementation of Unrolled 3D Convolutions

Given that torch.nn.Unfold can be used to unroll 2D convolutions so that they can be computed as vector-matrix multiplications (VMMs), and that the same unrolling approach can be used to compute 3D convolutions as VMMs (described in https://www.mdpi.com/2079-9292/8/1/65/pdf), how can PyTorch be used to unroll 3D convolutions?

There is currently no native implementation of torch.nn.Unfold for 5D tensors. However, a response on https://github.com/pytorch/pytorch/issues/30798 indicates that x.unfold(2, k, s).unfold(3, k, s).unfold(4, k, s) can be used. Unfortunately, with this approach I can’t get the same result as torch.nn.Conv3d when the same filter weights are used and biases are disabled.
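
To make the chained call concrete, here is a minimal sketch of what x.unfold(2, k, s).unfold(3, k, s).unfold(4, k, s) returns. The input size, k and s are illustrative values chosen for this example, not taken from the thread. Each unfold replaces one spatial dimension with the number of windows along it and appends a new trailing dimension of size k.

```python
import torch

# Illustrative 5D input laid out as (batch, channels, depth, height, width).
x = torch.arange(2 * 3 * 4 * 5 * 6, dtype=torch.float32).reshape(2, 3, 4, 5, 6)
k, s = 2, 1  # assumed kernel size and stride, the same along every axis

patches = x.unfold(2, k, s).unfold(3, k, s).unfold(4, k, s)

# Result layout: (B, C, out_D, out_H, out_W, k_D, k_H, k_W)
patch_shape = tuple(patches.shape)
print(patch_shape)

# The patch stored at output position (d, h, w) is the matching window of x.
d, h, w = 1, 2, 3
same = torch.equal(patches[0, 0, d, h, w], x[0, 0, d:d + k, h:h + k, w:w + k])
print(same)
```

Note that the channel dimension stays in position 1, ahead of the output positions, which matters when the patches are later flattened into rows.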

Any help would be greatly appreciated!

Could you post your code so that we can have a look at why it might give you a different result? :)

Sure! Please see the code below. The 2D Convolution block appears to work well. I have since managed to implement the 3D Convolution block using a similar approach, but it only appears to work when padding=(0, 0, 0). I’m using torch.Tensor.unfold to unfold 5D input tensors, and unfortunately it does not have a padding argument.

Do you have any idea how I could extend the 3D Convolution block below to support arbitrary padding lengths?

UPDATE: I have managed to use torch.nn.functional.pad to manually zero-pad the input before unfolding it, and it appears to work for arbitrary padding lengths. I have updated the code below.
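
The manual zero-padding step can be checked in isolation. The sketch below uses arbitrary example sizes (not the ones from the post) and compares a Conv3d with built-in padding against an unpadded Conv3d applied to an input padded with torch.nn.functional.pad. The point to watch is the argument order: F.pad takes (left, right) pairs starting from the last dimension, so the width padding comes first and the depth padding last.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
padding = (1, 2, 3)  # (depth, height, width), in the order nn.Conv3d expects
x = torch.randn(1, 2, 6, 7, 8)

# F.pad order: (W_left, W_right, H_top, H_bottom, D_front, D_back)
pad_arg = (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
x_padded = F.pad(x, pad_arg, mode="constant", value=0)
padded_shape = tuple(x_padded.shape)
print(padded_shape)

# Two convolutions sharing the same weights: one pads internally, one does not.
conv_padded = nn.Conv3d(2, 4, kernel_size=3, padding=padding, bias=False)
conv_plain = nn.Conv3d(2, 4, kernel_size=3, padding=0, bias=False)
conv_plain.load_state_dict(conv_padded.state_dict())

with torch.no_grad():
    matches = torch.allclose(conv_padded(x), conv_plain(x_padded), atol=1e-5)
print(matches)
```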

I believe this has now been resolved :)

import torch
from torch import nn
import torch.nn.functional as F


# 2D Convolution
input_dim = [4, 4]
in_channels = 2
out_channels = 4
kernel_size = (3, 3)
padding = (1, 1)
stride = (1, 1)
input_tensor = torch.zeros(1, in_channels, input_dim[0], input_dim[1]).uniform_(-1, 1)
conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
output_tensor = conv(input_tensor)
output_dim = [0, 0]
output_dim[0] = int((input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) / stride[0]) + 1
output_dim[1] = int((input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) / stride[1]) + 1
unfolded_input_tensor = F.unfold(input_tensor, kernel_size=kernel_size, padding=padding, stride=stride)
kernels_flat = conv.weight.detach().clone().view(out_channels, -1)
alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor).view(1, out_channels, output_dim[0], output_dim[1])
print(torch.all(torch.isclose(output_tensor, alt_output_tensor)))

# 3D Convolution
input_dim = [22, 59, 114]
in_channels = 1
out_channels = 16
kernel_size = (22, 5, 5)
padding = (1, 2, 2)
stride = (1, 2, 2)
input_tensor = torch.zeros(1, in_channels, input_dim[0], input_dim[1], input_dim[2]).uniform_(-1, 1)
conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
output_tensor = conv(input_tensor)
output_dim = [0, 0, 0]
output_dim[0] = int((input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) / stride[0]) + 1
output_dim[1] = int((input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) / stride[1]) + 1
output_dim[2] = int((input_tensor.shape[4] - kernel_size[2] + 2 * padding[2]) / stride[2]) + 1
if any(padding):
    input_tensor = F.pad(input_tensor, pad=(padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]), mode="constant", value=0)

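# After the three unfolds the layout is (B, C, out_D, out_H, out_W, k_D, k_H, k_W).
unfolded_input_tensor = input_tensor.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
# NOTE: this reshape only lines up with the flattened kernels when batch size and in_channels are both 1.
unfolded_input_tensor = unfolded_input_tensor.reshape(-1, kernel_size[0] * kernel_size[1] * kernel_size[2])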
kernels_flat = conv.weight.detach().clone().view(out_channels, -1)
alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor.T).view(1, out_channels, output_dim[0], output_dim[1], output_dim[2])
print(torch.all(torch.isclose(output_tensor, alt_output_tensor)))

Thanks for sharing your code which helps a lot!
However, I ran into some problems when I tried to modify in_channels.
For example, I set
input_dim = [5, 7, 7]
in_channels = 2
out_channels = 4
kernel_size = (4, 6, 6)
padding = (1, 1, 1)
stride = (1, 1, 1)
then it shows

RuntimeError: size mismatch, m1: [4 x 288], m2: [144 x 128] at /Users/distiller/project/conda/conda-bld/pytorch_1570710797334/work/aten/src/TH/generic/THTensorMath.cpp:197

then I set
unfolded_input_tensor = unfolded_input_tensor.reshape(-1, in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
since the kernel shape is out_c*in_c*k1*k2*k3.
Even though this gets past the error, the final result is tensor(False).

I would like to know how to unroll Conv3d with multiple input channels.
Thanks for your time to answer the question!
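
One likely explanation for the tensor(False) result: after the three unfold calls the tensor is laid out as (B, C, out_D, out_H, out_W, k_D, k_H, k_W), so a plain reshape to rows of length in_channels * k1 * k2 * k3 mixes values from different output positions instead of gathering all channels of a single patch. The sketch below uses the configuration from this post and moves the channel dimension next to the kernel dimensions before flattening. The variable names are chosen for this example, and it is a CPU-only check rather than a tuned implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
in_channels, out_channels = 2, 4
kernel_size, padding, stride = (4, 6, 6), (1, 1, 1), (1, 1, 1)
x = torch.randn(1, in_channels, 5, 7, 7)
conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=False)

# Zero-pad manually; F.pad expects the last dimension's padding first.
x_pad = F.pad(x, (padding[2],) * 2 + (padding[1],) * 2 + (padding[0],) * 2)
patches = (x_pad.unfold(2, kernel_size[0], stride[0])
                .unfold(3, kernel_size[1], stride[1])
                .unfold(4, kernel_size[2], stride[2]))
B, C, oD, oH, oW = patches.shape[:5]  # patches: (B, C, oD, oH, oW, kD, kH, kW)

weights = conv.weight.detach().reshape(out_channels, -1)  # (out_channels, C*kD*kH*kW)

# Naive flattening: channels are not adjacent to the kernel dims, so rows are scrambled.
naive = patches.reshape(B, oD * oH * oW, -1)
out_naive = torch.matmul(weights, naive.transpose(1, 2)).reshape(B, out_channels, oD, oH, oW)

# Permute to (B, oD, oH, oW, C, kD, kH, kW) first, matching the weight layout.
cols = patches.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(B, oD * oH * oW, -1)
out = torch.matmul(weights, cols.transpose(1, 2)).reshape(B, out_channels, oD, oH, oW)

with torch.no_grad():
    reference = conv(x)
naive_matches = torch.allclose(reference, out_naive, atol=1e-4)
matches = torch.allclose(reference, out, atol=1e-4)
print(naive_matches, matches)
```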

The above code does not seem to work when in_channels is greater than 1.

Here is my approach which seems to work well.


    import numpy as np
    import torch
    from torch import nn
    import torch.nn.functional as F

    inputShape = [128, 128, 128]    
    batchSioze = 2
    CIn        = 4
    COut       = 8    
    kernelSize = (10,5,3)
    pad        = (2,3,1)
    stride     = (1,2,1)

    # normal conv
    conv = nn.Conv3d(CIn, COut, kernelSize, stride, pad, bias=False).cuda()
                      
    # alternativeConv
    def alternativeConv(X, K, 
                        COut       = None,
                        kernelSize = (3,3,3),
                        pad        = (1,1,1),
                        stride     = (1,1,1) ):

        def unfold3d(tensor, kernelSize, pad, stride): 

            B, C, _, _, _ = tensor.shape

            # Input shape: (B, C, D, H, W)
            tensor = F.pad(tensor,
                           (pad[2], pad[2],
                            pad[1], pad[1],
                            pad[0], pad[0])
                          )

            tensor = (tensor
                      .unfold(2, size=kernelSize[0], step=stride[0])
                      .unfold(3, size=kernelSize[1], step=stride[1])
                      .unfold(4, size=kernelSize[2], step=stride[2])
                      .permute(0, 2, 3, 4, 1, 5, 6, 7)
                      .reshape(B, -1, C * np.prod(kernelSize))
                      .transpose(1, 2)
                     )
            
            return tensor
    
        # Conv3d inputs are laid out as (B, C, D, H, W)
        B, CIn, D, H, W = X.shape
        outShape = (np.array([D, H, W]) - np.array(kernelSize) + 2 * np.array(pad)) // np.array(stride) + 1
        outShape = outShape.astype(np.int32)
        
        X = unfold3d(X, kernelSize, pad, stride)
  
        K = K.view(COut, -1)
        #K = torch.randn(COut, CIn, *kernelSize).cuda() 
        #K = K.view(COut, -1)
                    
        Y = torch.matmul(K, X).view(B, COut, *outShape)
        
        return Y
    
    X = torch.randn(batchSioze, CIn, *inputShape).cuda()
    
    Y1 = conv(X)
    
    Y2 = alternativeConv(X, conv.weight, 
                         COut       = COut,
                         kernelSize = kernelSize,
                         pad        = pad,
                         stride     = stride
                         )
    
    print(torch.all(torch.isclose(Y1, Y2)))   

Hi,

Not sure if you’re still looking for a solution, but I have recently written a small package named unfoldNd that should do what you want.

unfoldNd.UnfoldNd works exactly like torch.nn.Unfold, also for 5d inputs. Here is a small demo.

Hope it’s useful.
