Manual Implementation of Unrolled 3D Convolutions

The code above does not seem to work when `in_channels` is greater than 1.

Here is my approach, which seems to work well:


    import numpy as np
    import torch
    from torch import nn
    import torch.nn.functional as F

    # Problem configuration (input volume is inputShape, per sample and channel).
    inputShape = [128, 128, 128]
    batchSize  = 2
    batchSioze = batchSize   # original (misspelled) name, still referenced further down
    CIn        = 4
    COut       = 8
    kernelSize = (10,5,3)
    pad        = (2,3,1)
    stride     = (1,2,1)
    # NOTE: with these settings the unfolded input used by the unfold+matmul
    # approach holds roughly batchSize * CIn * prod(kernelSize) * prod(outShape)
    # = 2 * 600 * (123*65*128) ~= 1.2e9 floats (~5 GB); shrink inputShape if
    # that does not fit on the device.

    # Reference convolution; fall back to CPU when no GPU is available instead
    # of crashing on an unconditional .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    conv = nn.Conv3d(CIn, COut, kernelSize, stride, pad, bias=False).to(device)
                      
    # alternativeConv
    def alternativeConv(X, K,
                        COut       = None,
                        kernelSize = (3,3,3),
                        pad        = (1,1,1),
                        stride     = (1,1,1) ):
        """Compute a bias-free 3D convolution as unfold (im2col) + matrix multiply.

        Intended to match ``F.conv3d(X, K, stride=stride, padding=pad)`` with
        dilation 1 and groups 1.

        Args:
            X: input tensor of shape (B, C, D, H, W).
            K: weight tensor of shape (COut, C, kD, kH, kW), e.g. ``conv.weight``.
            COut: number of output channels; inferred from ``K.shape[0]`` when None.
            kernelSize: int or (kD, kH, kW); must match the spatial dims of K.
            pad: int or (pD, pH, pW); zero padding applied on both sides of each dim.
            stride: int or (sD, sH, sW).

        Returns:
            Tensor of shape (B, COut, oD, oH, oW).
        """

        def _triple(v):
            # Accept a scalar like nn.Conv3d does and broadcast it to all three dims.
            if np.ndim(v) == 0:
                return (int(v),) * 3
            return tuple(int(x) for x in v)

        kernelSize = _triple(kernelSize)
        pad        = _triple(pad)
        stride     = _triple(stride)

        def unfold3d(tensor, kernelSize, pad, stride):
            """Return (columns, outShape) with columns of shape (B, C*kD*kH*kW, oD*oH*oW)."""
            # Input shape: (B, C, D, H, W)
            B, C, _, _, _ = tensor.shape

            # F.pad takes (left, right) pairs starting from the LAST dim: W, H, D.
            tensor = F.pad(tensor,
                           (pad[2], pad[2],
                            pad[1], pad[1],
                            pad[0], pad[0])
                          )

            # -> (B, C, oD, oH, oW, kD, kH, kW)
            patches = (tensor
                       .unfold(2, size=kernelSize[0], step=stride[0])
                       .unfold(3, size=kernelSize[1], step=stride[1])
                       .unfold(4, size=kernelSize[2], step=stride[2])
                      )
            # Read the output size off the unfolded tensor itself (as Python
            # ints) so it can never disagree with what unfold produced.
            outShape = tuple(int(s) for s in patches.shape[2:5])

            # -> (B, oD*oH*oW, C*kD*kH*kW) -> (B, C*kD*kH*kW, oD*oH*oW).
            # Per-row order (C, kD, kH, kW) matches K.reshape(COut, -1).
            columns = (patches
                       .permute(0, 2, 3, 4, 1, 5, 6, 7)
                       .reshape(B, -1, C * int(np.prod(kernelSize)))
                       .transpose(1, 2)
                      )
            return columns, outShape

        B = X.shape[0]
        if COut is None:
            # The documented default: derive the output channel count from the weight.
            COut = K.shape[0]

        columns, outShape = unfold3d(X, kernelSize, pad, stride)

        # reshape (not view) so a non-contiguous weight tensor also works.
        K = K.reshape(COut, -1)

        # (COut, C*k) @ (B, C*k, L) broadcasts to (B, COut, L).
        Y = torch.matmul(K, columns).view(B, COut, *outShape)

        return Y
    
    # Build the input on whatever device the reference layer lives on, instead
    # of forcing CUDA.
    X = torch.randn(batchSioze, CIn, *inputShape, device=conv.weight.device)

    # no_grad: conv.weight requires grad, so without it autograd would keep the
    # (very large) unfolded input alive for a backward pass that is never run.
    with torch.no_grad():
        Y1 = conv(X)

        Y2 = alternativeConv(X, conv.weight,
                             COut       = COut,
                             kernelSize = kernelSize,
                             pad        = pad,
                             stride     = stride
                             )

    # A layout/indexing bug shows up as O(1) errors, so a modest explicit
    # tolerance still discriminates while absorbing rounding differences from a
    # different accumulation order (and possibly reduced-precision GPU math).
    print("shapes:", tuple(Y1.shape), tuple(Y2.shape))
    print("max abs diff:", (Y1 - Y2).abs().max().item())
    print(torch.allclose(Y1, Y2, rtol=1e-3, atol=1e-3))
1 Like