Consider torch.nn.Conv2d(3, 64, 3) applied to a single 32 x 32 RGB (3-channel) image.
Q1: I believe that means there is a separate kernel for each of the 3 input channels? In other
words, each of the R, G, and B channels has its own kernel applied to its own 32 x 32 grid, so you
have a triplet of kernels!?
Q2: Are the outputs of those three kernels somehow combined? How? I’m picturing
3 output grids that need to be combined into one final grid of the same size?
Q3: Does the 64 mean you have 64 different kernel triplets for a total of
3 x 64 different kernels applied to the 32 x 32 grids?
Thanks,
Chris
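A minimal sketch illustrating Q1-Q3, assuming the torch.nn.Conv2d(3, 64, 3) layer and a single 32 x 32 RGB input from the question: the layer stores one weight tensor containing 64 stacks of 3 kernels each (so 3 x 64 = 192 individual 3x3 kernels), and each stack produces one output channel.
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, 3)

# weight shape is [out_channels, in_channels, kH, kW]:
# 64 kernel "triplets", i.e. 3 x 64 = 192 individual 3x3 kernels
print(conv.weight.shape)
# torch.Size([64, 3, 3, 3])

# a single 32 x 32 RGB image (with a batch dimension of 1)
x = torch.randn(1, 3, 32, 32)
out = conv(x)
print(out.shape)
# torch.Size([1, 64, 30, 30])  (no padding, so each 32 shrinks to 30)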
Thank you very much. So if I understand you correctly, after applying a kernel, the output will have one fewer dimension?
For example, if you apply a kernel to an image with dimensions W x H x N, the output will have dimensions w <= W and h <= H?
(Assume N is the number of channels.)
No, the number of dimensions will be the same. The conv layer expects an input of [batch_size, in_channels, height0, width0]
and creates an output of [batch_size, out_channels, height1, width1],
where the spatial sizes are defined by the kernel size, stride, dilation, padding, etc., as explained in the docs.
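For example (a small sketch, using hypothetical settings of stride=2 and padding=1 rather than the question's defaults), the output spatial size follows the formula from the docs:
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
x = torch.randn(8, 3, 32, 32)   # [batch_size, in_channels, height0, width0]
print(conv(x).shape)
# torch.Size([8, 64, 16, 16])   # [batch_size, out_channels, height1, width1]

# height1 = floor((height0 + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
#         = floor((32 + 2 - 2 - 1) / 2 + 1) = floor(16.5) = 16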
ptrblck
Thanks again. If I can please ask one more question…
Using your notation, would it be correct to say the
conv layer applies multiple kernels to the input
where the number of kernels applied is out_channels?
Chris
Yes, this is the case in the default setup (not for depthwise or grouped convs), and you can verify it by applying each kernel separately:
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 3, 224, 224)
conv = nn.Conv2d(3, 10, 3, 1, 1, bias=False)
ref = conv(x)

# apply each kernel separately
out = []
for kernel in conv.weight:
    # each kernel has shape [in_channels, kH, kW]; add an out_channels dim of 1
    kernel = kernel.unsqueeze(0)
    print(kernel.shape)
    # torch.Size([1, 3, 3, 3])
    tmp = F.conv2d(x, kernel, bias=None, stride=1, padding=1)
    out.append(tmp)

# concatenate the 10 single-channel outputs along the channel dimension
out = torch.cat(out, dim=1)
print(out.shape)
# torch.Size([2, 10, 224, 224])

# the stacked per-kernel outputs match the layer's own output
print((ref - out).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)
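And to see how the three per-channel results are combined within a single kernel (they are simply summed over the input channels), here is a short sketch reusing x, conv, and ref from above; it reproduces output channel 0 by hand:
# take the first kernel triplet: shape [1, in_channels, 3, 3]
single = conv.weight[0:1]
per_channel = []
for c in range(x.size(1)):
    # convolve input channel c with its matching slice of the kernel
    tmp = F.conv2d(x[:, c:c+1], single[:, c:c+1], stride=1, padding=1)
    per_channel.append(tmp)

# summing the 3 per-channel grids gives the layer's output channel 0
summed = torch.stack(per_channel, dim=0).sum(dim=0)
print((ref[:, 0:1] - summed).abs().max())
# close to 0 (up to floating point rounding)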
Wonderful! Thanks so much. This was really confusing me.
cs