The above code does not seem to work when `in_channels` is greater than 1.
Here is my approach, which seems to work well:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
inputShape = [128, 128, 128]
batchSize = 2
CIn = 4
COut = 8
kernelSize = (10, 5, 3)
pad = (2, 3, 1)
stride = (1, 2, 1)
# reference convolution
conv = nn.Conv3d(CIn, COut, kernelSize, stride, pad, bias=False).cuda()
# alternative convolution via unfold (im2col) + matmul
def alternativeConv(X, K,
                    COut=None,
                    kernelSize=(3, 3, 3),
                    pad=(1, 1, 1),
                    stride=(1, 1, 1)):
    def unfold3d(tensor, kernelSize, pad, stride):
        # Input shape: (B, C, D, H, W)
        B, C, _, _, _ = tensor.shape
        # F.pad pads the last dimension first: (W, W, H, H, D, D)
        tensor = F.pad(tensor,
                       (pad[2], pad[2],
                        pad[1], pad[1],
                        pad[0], pad[0]))
        # Extract sliding patches along D, H and W, gather the channel and
        # kernel axes together, and flatten each patch into one column.
        tensor = (tensor
                  .unfold(2, size=kernelSize[0], step=stride[0])
                  .unfold(3, size=kernelSize[1], step=stride[1])
                  .unfold(4, size=kernelSize[2], step=stride[2])
                  .permute(0, 2, 3, 4, 1, 5, 6, 7)
                  .reshape(B, -1, C * np.prod(kernelSize))
                  .transpose(1, 2))
        return tensor
    B, C, D, H, W = X.shape
    # standard convolution output size per dimension: (N + 2P - K) // S + 1
    outShape = (np.array([D, H, W]) + 2 * np.array(pad)
                - np.array(kernelSize)) // np.array(stride) + 1
    X = unfold3d(X, kernelSize, pad, stride)         # (B, C * prod(k), L)
    K = K.view(COut, -1)                             # (COut, C * prod(k))
    Y = torch.matmul(K, X).view(B, COut, *outShape)  # (B, COut, D', H', W')
    return Y
X = torch.randn(batchSize, CIn, *inputShape).cuda()
Y1 = conv(X)
Y2 = alternativeConv(X, conv.weight,
                     COut=COut,
                     kernelSize=kernelSize,
                     pad=pad,
                     stride=stride)
print(torch.allclose(Y1, Y2))  # same as torch.all(torch.isclose(...)); should print True
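For intuition, the same im2col idea is easier to see in 2D, where PyTorch already ships `F.unfold`. Below is a minimal sketch of the 2D analogue (the shapes and names are my own illustration, not part of the code above):

import torch
import torch.nn.functional as F
from torch import nn

B, CIn, COut = 2, 4, 8
H = W = 32
kernel, pad, stride = (5, 3), (2, 1), (1, 2)

conv2d = nn.Conv2d(CIn, COut, kernel, stride, pad, bias=False)
x = torch.randn(B, CIn, H, W)

# im2col: shape (B, CIn * kH * kW, L), one column per sliding position
cols = F.unfold(x, kernel, padding=pad, stride=stride)
outH = (H + 2 * pad[0] - kernel[0]) // stride[0] + 1
outW = (W + 2 * pad[1] - kernel[1]) // stride[1] + 1
# the convolution becomes one matrix multiply, reshaped to the output grid
y = (conv2d.weight.view(COut, -1) @ cols).view(B, COut, outH, outW)

print(torch.allclose(conv2d(x), y, atol=1e-5))

One caveat: unfolding materializes every patch, so it needs roughly prod(kernelSize) times the memory of the padded input. It is useful for understanding the operation or for building custom ops on patches, not as a drop-in replacement for nn.Conv3d.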