Sorry for the late answer; here is the idea.
If you look up the definition of multi-channel cross-correlation, which is also given in the Conv2d docs, you will see this formula:

```
out(N_i, C_out_j) = bias(C_out_j) + sum_{k=0}^{C_in - 1} weight(C_out_j, k) ⋆ input(N_i, k)
```

It says that, for each output channel, you need to combine the per-input-channel correlation results with a sum. In your code, you have removed that correlation (the summation) between the different input channels.
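To make the summation concrete, here is a small check (with toy shapes of my choosing) that a full `F.conv2d` call equals the sum of per-input-channel correlations:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 5, 5)   # batch=1, C_in=2
w = torch.randn(3, 2, 3, 3)   # C_out=3, C_in=2, 3x3 kernel

out = F.conv2d(x, w, padding=1)

# Each output channel j is the SUM over input channels k of
# cross-correlating input channel k with the kernel slice w[j, k].
manual = sum(
    F.conv2d(x[:, k:k+1], w[:, k:k+1], padding=1)
    for k in range(x.shape[1])
)
print(torch.allclose(out, manual, atol=1e-5))  # True
```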
Let’s talk intuitively. Say we have an input tensor `x` of size `[in_channel, h, w]` and we want 3 output channels, i.e. a result of size `[3, h, w]`. Then we need to convolve `x` with an `[in_channel, k, k]` kernel three times, once per output channel, and concatenate the results. But here is the point: if the kernel is the same for all three output channels, shouldn’t the results also be the same, since both `x` and the kernel are identical for each output channel? The answer is yes, and that is exactly what the summation over input channels guarantees; if we remove it, we break this property.
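As a quick sanity check of that intuition (toy shapes of my choosing): repeating one kernel across all output channels makes every output channel identical:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 7, 7)
k = torch.randn(1, 3, 3, 3)   # one [3, 3, 3] kernel
w = k.repeat(3, 1, 1, 1)      # same kernel for all 3 output channels

out = F.conv2d(x, w, padding=1)  # [1, 3, 7, 7]
# All three output channels are identical, because both x and the
# kernel are the same for each of them.
print(torch.allclose(out[:, 0], out[:, 1]) and torch.allclose(out[:, 1], out[:, 2]))  # True
```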
Here is your code with some modifications:
```python
import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)

spread_kernel = None
kernel_ = None

def spread_dpv(dpv, N=5):
    global kernel_
    dpv_permuted = dpv.clone()
    kernel = torch.Tensor(np.zeros((N, N)).astype(np.float32))
    kernel[int(N / 2), :] = 1.
    # weight shape: [out_channels, in_channels, N, N]
    kernel = kernel.repeat((dpv_permuted.shape[1], dpv_permuted.shape[1], 1, 1)).to(dpv_permuted.device)
    kernel_ = kernel
    dpv_permuted = F.conv2d(input=dpv_permuted, weight=kernel, padding=N // 2)
    dpv = dpv_permuted
    return dpv

def spread_dpv_hack(dpv, N=5):
    global spread_kernel
    dpv_permuted = dpv.clone()
    if spread_kernel is None:
        kernel = torch.Tensor(np.zeros((N, N)).astype(np.float32))
        kernel[int(N / 2), :] = 1.
        # weight shape: [1, 1, N, N] -- a single single-channel kernel
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        kernel = {'weight': kernel.to(dpv_permuted.device), 'padding': N // 2}
        spread_kernel = kernel.copy()
    # convolve each (batch, channel) slice independently
    for b in range(dpv_permuted.shape[0]):
        for c in range(dpv_permuted.shape[1]):
            dpv_permuted[b, c, :, :] = F.conv2d(dpv_permuted[b:b+1, c:c+1], **spread_kernel)
    dpv = dpv_permuted
    return dpv

x = torch.randint(1, 3, (1, 3, 7, 7)).float()
s = spread_dpv(x)
sh = spread_dpv_hack(x)
```
Modifications:

- PyTorch convolutions expect channel-first (NCHW) input, so you should remove the `.permute` lines.
- For simplicity I removed the normalization parts.
- You were creating a different kernel for the first method by dividing by `float(N)`, which was not done in the second method.
If you take `sh`, which is the output of your method, and sum it over the channel dimension, you get one output channel of `s`:

```python
np.sum(sh.numpy(), axis=(0, 1))  # equals s[0, 0]
```
Also, we can express it by summing the way I initialized the kernel:

```python
F.conv2d(x[:, 0:1], kernel_[0:1, 0:1], padding=2) \
    + F.conv2d(x[:, 1:2], kernel_[0:1, 1:2], padding=2) \
    + F.conv2d(x[:, 2:3], kernel_[0:1, 2:3], padding=2)
```
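As a side note: if the goal is really to convolve each channel independently, the `groups` argument of `F.conv2d` does this without the Python loops. A sketch using the same line-shaped kernel (this is my suggestion, not part of your original code):

```python
import torch
import torch.nn.functional as F

N = 5
x = torch.randint(1, 3, (1, 3, 7, 7)).float()

kernel = torch.zeros(N, N)
kernel[N // 2, :] = 1.
# depthwise weight: [C, 1, N, N] -- one single-channel kernel per group
depthwise = kernel.repeat(x.shape[1], 1, 1, 1)

# groups=C convolves each input channel with its own kernel slice,
# with no summation across channels.
out = F.conv2d(x, depthwise, padding=N // 2, groups=x.shape[1])
```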