Visualizing the outputs of kernels, instead of the outputs of filters

Hello

As I have seen in many posts in this forum, people have explained how to visualize featuremaps (i.e., the outputs of the filters).
Since the mentioned featuremaps are the summation of all the kernel outputs, my question is: how can we access the featuremaps produced by the individual kernels, right before they are summed up to produce the final featuremaps?

Thank you

Each “featuremap” is created by a single kernel. E.g., for a conv layer created via nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3), the output activation (featuremap) will have the shape [batch_size, out_channels=10, height, width]. The out_channels dimension corresponds to the number of filters/kernels in the layer (in the default setup) and each filter creates one activation map.
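
For example, a minimal sketch of grabbing such an activation with a forward hook (the shapes and the dict key are just for illustration):

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)

acts = {}
def hook(module, input, output):
    # store the output activation (featuremap) of this layer
    acts['conv'] = output.detach()

conv.register_forward_hook(hook)

x = torch.rand(1, 3, 24, 24)
out = conv(x)
print(acts['conv'].shape)
# > torch.Size([1, 10, 22, 22]) -> one activation map per filter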

I guess this refers to the applied operation of each filter on the input activation. If you want to split the channels of the filter and apply them separately, you could use the groups argument during the instantiation of the conv layer.
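
A small sketch just showing how groups changes the weight layout (the shapes are arbitrary):

import torch.nn as nn

# standard conv: each of the 8 filters uses all 4 input channels
conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.weight.shape)
# > torch.Size([8, 4, 3, 3])

# grouped conv with groups=4: each filter only sees a single input channel
conv_grouped = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, groups=4)
print(conv_grouped.weight.shape)
# > torch.Size([8, 1, 3, 3])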

Hi @ptrblck

Thank you for your explanations.
In my explanation, I consider the kernels to be the slices along the channel dimension of each filter.
Based on your example, each one of the 10 featuremaps refers to a filter which contains three kernels (because of the input dimension).
In other words, during the conv operation 30 featuremaps are produced, which are afterwards summed up to 10 featuremaps.
I meant, instead of accessing the 10 featuremaps using a hook function, accessing those 30 featuremaps.
Is the groups argument you mentioned a solution to what I want (the 30 featuremaps)? If so, is it possible to apply such an operation to a pre-trained CNN such as VGG19, or shall I train it from scratch with the conv layers using the groups argument?

Thank you very much.

Filters and kernels usually refer to the same object, but if I understand you correctly, you are calling the 3D tensor a filter and each 2D slice a kernel?

If so, it seems you would like to get the activations after each “filter” was applied to the corresponding input activation channel and after the spatial dimensions (height and width) were already reduced, BUT before the channel dimension is reduced?

It would be possible to achieve this using grouped convolutions, but you would need to perform a lot of reshaping and would need to know how the conv kernels are applied internally.
For grouped convs, take a look at this post first, as it gives you some information about the order in which the kernels are applied.

Let’s assume you are working with an input activation in the shape [batch_size=1, channels=4, height=24, width=24] and a standard conv layer with a weight tensor in the shape [out_channels=8, in_channels=4, height=3, width=3]. To keep the spatial dimension equal, let’s also use padding=1 and stride=1. Let’s also set the bias to None to make it a bit easier.
Now, you could use a grouped convolution via nn.Conv2d(4, 8*4, 3, 1, 1, groups=4), however you would need to be careful about the order how these filters are applied.

The main issue is the permutation needed in the filters, as the grouped conv is applied separately to each input channel (as seen in the linked post), while the standard conv applies a “full” filter with all its channels to the input, so all input channels are used by each filter.
It’s a bit hard to describe in words and although my drawing skills are not perfect, here is an illustration:

And here the corresponding code:

import torch
import torch.nn as nn

N, C, H, W = 1, 4, 24, 24

for _ in range(100):
    N, C, H, W = torch.randint(1, 100, (4,))
    
    x = torch.rand(N, C, H, W)
    
    out_channels = 8
    kw = 3
    conv = nn.Conv2d(C, out_channels, kw, 1, 1, bias=False)
    out = conv(x)
    #print(out.shape)
    # > torch.Size([N, out_channels, H, W])
    
    # grouped conv creating one output channel per (input channel, filter) pair
    conv_grouped = nn.Conv2d(C, C*out_channels, kw, 1, 1, groups=C, bias=False)
    with torch.no_grad():
        # copy the plain conv weight so that group i holds the i-th 2D kernel of every filter
        conv_grouped.weight.copy_(conv.weight.permute(1, 0, 2, 3).reshape(C*out_channels, 1, kw, kw))
    out_grouped = conv_grouped(x)
    #print(out_grouped.shape)
    # > torch.Size([N, C*out_channels, H, W])
    
    # reorder the channels from "input channel major" to "filter major"
    out_grouped = out_grouped.view(N, C, out_channels, H, W).permute(0, 2, 1, 3, 4).reshape(N, C*out_channels, H, W)
    
    # manually reduce
    idx = torch.arange(out_channels)
    idx = torch.repeat_interleave(idx, C)
    idx = idx[None, :, None, None].expand(N, -1, H, W)
    out_grouped_reduced = torch.zeros_like(out)
    out_grouped_reduced.scatter_add_(dim=1,index=idx, src=out_grouped)
    
    # check error
    print(torch.allclose(out_grouped_reduced, out, atol=5e-6), (out_grouped_reduced - out).abs().max())

As you can see, the outputs are equal after the “manual” reduction.
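
If you only need the per-channel outputs for visualization, a functional sketch (under the same assumptions as above, i.e. stride=1, padding=1, no bias) could also avoid creating a second module:

import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, H, W = 1, 4, 24, 24
out_channels, kw = 8, 3
x = torch.rand(N, C, H, W)
conv = nn.Conv2d(C, out_channels, kw, 1, 1, bias=False)

# reuse the plain conv weight as a grouped kernel; channel i*out_channels + j
# then corresponds to (input channel i, filter j)
w = conv.weight.permute(1, 0, 2, 3).reshape(C * out_channels, 1, kw, kw)
out_per_channel = F.conv2d(x, w, bias=None, stride=1, padding=1, groups=C)
print(out_per_channel.shape)
# > torch.Size([1, 32, 24, 24])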

You could manipulate the pretrained VGG19 by replacing all conv layers with their grouped equivalent. I think writing a custom conv layer implementation would be beneficial given the complexity of the code.
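
A rough sketch of such a replacement (assuming torchvision's vgg19 and a hypothetical CustomConv wrapper implementing the approach from above):

import torch.nn as nn
import torchvision.models as models

model = models.vgg19(pretrained=True)

# replace every conv layer in the feature extractor with its grouped equivalent
for idx, layer in enumerate(model.features):
    if isinstance(layer, nn.Conv2d):
        model.features[idx] = CustomConv(layer)  # CustomConv is your custom wrapper module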

Hi @ptrblck

First of all, I am thankful for your explanations and for providing the links and the image, which really helped me to understand how the groups option works.

However, there is still something unclear to me.
If I am not wrong, based on your scheme it is possible that channels from different filters are grouped together. Is that right? Because in the standard one each filter has 4 channels of size 3x3, while in the grouped one each filter has 8 channels of size 3x3.
Could you please take a look at this link, where an example of the convolution operation is explained?
Here is part of that link:
"Here the input layer is a 5 x 5 x 3 matrix, with 3 channels. The filter is a 3 x 3 x 3 matrix. First, each of the kernels in the filter are applied to three channels in the input layer, separately. Three convolutions are performed, which result in 3 channels with size 3 x 3. Then these three channels are summed together (element-wise addition) to form one single channel (3 x 3 x 1). This channel is the result of convolution of the input layer (5 x 5 x 3 matrix) using a filter (3 x 3 x 3 matrix)."

Then, based on your standard convolutional layer, each filter has the shape [4, 3, 3], such that each [1, 3, 3] kernel is separately applied to the corresponding [1, 24, 24] input activation. When the convolution operation is done, the output will be [4, 24, 24] for one filter; afterwards, these are summed across the channel dimension into an output activation with the shape [1, 24, 24]. Am I right?
And what I need is the activation for each filter with the shape [4, 24, 24] instead of [1, 24, 24].

Your explanation sounds right and fits into the standard conv approach shown in the attached picture.

My code snippet provides exactly this, with the additional channel dim reduction used to verify that the outputs are equal.
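
For example (a small sketch reusing the tensors from my previous code snippet), the per-filter activations you are looking for would be:

# out_grouped (after the permute/reshape) has the shape [N, C*out_channels, H, W]
# and is ordered filter-major, so viewing it as 5D separates the filters again
per_filter = out_grouped.view(N, out_channels, C, H, W)
print(per_filter[0, 0].shape)
# > torch.Size([C, H, W]) -> the C per-channel maps of the first filter, e.g. [4, 24, 24]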

Dear @ptrblck

A few months ago, I created this topic.
I re-implemented your code snippet to replace the standard conv layers in AlexNet with the new one discussed in this topic. However, when I use the torch.allclose function (as implemented in your code) to figure out whether the outputs of the standard conv layer and the modified one are equal, the result is always False, even though the program runs without any error. I mean, the input is passed through all layers without error, but the output of torch.allclose is always False. The difference between your code and my code is that in your code you set the bias parameter to False for simplification, while my code also considers the bias values.
I provide my code below. I would really appreciate it if you could help me out. What is wrong with my code that makes the outputs different?

Thank you.

The code snippet for modifying the conv layers:

for name, module in alex_net._modules.items():
    if name == 'features':
        for n, m in module._modules.items():
            if type(m) == torch.nn.modules.conv.Conv2d:
                alex_net._modules[name][int(n)] = custom_conv(alex_net._modules[name][int(n)])

print(alex_net)

Output (I keep the standard conv layer in the custom_conv class to compare their outputs):

AlexNet(
  (features): Sequential(
    (0): custom_conv(
      (original_conv): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (conv_grouped): Conv2d(3, 192, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), groups=3)
    )
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): custom_conv(
      (original_conv): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (conv_grouped): Conv2d(64, 12288, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64)
    )
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): custom_conv(
      (original_conv): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_grouped): Conv2d(192, 73728, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192)
    )
    (7): ReLU(inplace=True)
    (8): custom_conv(
      (original_conv): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_grouped): Conv2d(384, 98304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)
    )
    (9): ReLU(inplace=True)
    (10): custom_conv(
      (original_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_grouped): Conv2d(256, 65536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
    )
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

The custom conv layer implementation:

class custom_conv(nn.Module):
    def __init__(self, conv_module):
        super(custom_conv, self).__init__()
        self.original_conv = conv_module
        self.in_channels = conv_module.in_channels
        self.out_channels = conv_module.out_channels
        self.kernel_size = conv_module.kernel_size
        self.stride = conv_module.stride
        self.padding = conv_module.padding
        self.conv_grouped = nn.Conv2d(self.in_channels, self.in_channels*self.out_channels, self.kernel_size,
                                      self.stride, self.padding, groups=self.in_channels, bias=True)
        with torch.no_grad():
            self.conv_grouped.weight.copy_(conv_module.weight.permute(1, 0, 2, 3).
                reshape(self.in_channels * self.out_channels, 1, self.kernel_size[0], self.kernel_size[1]))

    def __call__(self, x):
        out = self.original_conv(x)
        #print('original output ', out.shape)
        N = out.shape[0]
        H = out.shape[2]
        W = out.shape[3]
        out_grouped = self.conv_grouped(x)
        #print(out_grouped.shape)
        out_grouped = out_grouped.view(N, self.in_channels, self.out_channels, H, W).\
            permute(0, 2, 1, 3, 4).reshape(N, self.in_channels*self.out_channels, H, W)
        idx = torch.arange(self.out_channels)
        idx = torch.repeat_interleave(idx, self.in_channels)
        idx = idx[None, :, None, None].expand(N, -1, H, W)
        out_grouped_reduced = torch.zeros_like(out)
        out_grouped_reduced.scatter_add_(dim=1, index=idx, src=out_grouped)
        print(torch.allclose(out_grouped_reduced, out, atol=5e-6), (out_grouped_reduced - out).abs().max())

        return out_grouped_reduced, out

I pass a random input with the shape [1, 3, 224, 224] and the output of my code is as follows:

layer index: 0
False tensor(0.2228, grad_fn=)
original shape : torch.Size([1, 64, 55, 55])
grouped shape: torch.Size([1, 64, 55, 55])
layer index: 1
original shape : torch.Size([1, 64, 55, 55])
grouped shape: torch.Size([1, 64, 55, 55])
layer index: 2
original shape : torch.Size([1, 64, 27, 27])
grouped shape: torch.Size([1, 64, 27, 27])
layer index: 3
False tensor(2.6001, grad_fn=)
original shape : torch.Size([1, 192, 27, 27])
grouped shape: torch.Size([1, 192, 27, 27])
layer index: 4
original shape : torch.Size([1, 192, 27, 27])
grouped shape: torch.Size([1, 192, 27, 27])
layer index: 5
original shape : torch.Size([1, 192, 13, 13])
grouped shape: torch.Size([1, 192, 13, 13])
layer index: 6
False tensor(10.5771, grad_fn=)
original shape : torch.Size([1, 384, 13, 13])
grouped shape: torch.Size([1, 384, 13, 13])
layer index: 7
original shape : torch.Size([1, 384, 13, 13])
grouped shape: torch.Size([1, 384, 13, 13])
layer index: 8
False tensor(10.4743, grad_fn=)
original shape : torch.Size([1, 256, 13, 13])
grouped shape: torch.Size([1, 256, 13, 13])
layer index: 9
original shape : torch.Size([1, 256, 13, 13])
grouped shape: torch.Size([1, 256, 13, 13])
layer index: 10
False tensor(10.4024, grad_fn=)
original shape : torch.Size([1, 256, 13, 13])
grouped shape: torch.Size([1, 256, 13, 13])
layer index: 11
original shape : torch.Size([1, 256, 13, 13])
grouped shape: torch.Size([1, 256, 13, 13])
layer index: 12
original shape : torch.Size([1, 256, 6, 6])
grouped shape: torch.Size([1, 256, 6, 6])
layer index: 12
original shape : torch.Size([1, 256, 6, 6])
grouped shape: torch.Size([1, 256, 6, 6])
layer index: 12
original shape : torch.Size([1, 1000])
grouped shape: torch.Size([1, 1000])

I don’t know how you are processing the bias, but you would either have to scale it with 1/C, if you are using the bias directly in the grouped conv, or you could add it after the reduction as seen here:

import torch
import torch.nn as nn

N, C, H, W = 1, 4, 24, 24

for _ in range(100):
    N, C, H, W = torch.randint(1, 100, (4,))
    
    x = torch.rand(N, C, H, W)
    
    out_channels = 8
    kw = 3
    conv = nn.Conv2d(C, out_channels, kw, 1, 1, bias=True)
    out = conv(x)
    #print(out.shape)
    # > torch.Size([N, out_channels, H, W])
    
    conv_grouped = nn.Conv2d(C, C*out_channels, kw, 1, 1, groups=C, bias=True)
    with torch.no_grad():
        conv_grouped.weight.copy_(conv.weight.permute(1, 0, 2, 3).reshape(C*out_channels, 1, kw, kw))
        conv_grouped.bias.copy_(conv.bias.repeat(C) / C) # scale bias, since it's added after the channel reduction in the plain conv
    out_grouped = conv_grouped(x)
    #print(out_grouped.shape)
    # > torch.Size([N, C*out_channels, H, W])
    
    out_grouped = out_grouped.view(N, C, out_channels, H, W).permute(0, 2, 1, 3, 4).reshape(N, C*out_channels, H, W)
    
    # manually reduce
    idx = torch.arange(out_channels)
    idx = torch.repeat_interleave(idx, C)
    idx = idx[None, :, None, None].expand(N, -1, H, W)
    out_grouped_reduced = torch.zeros_like(out)
    out_grouped_reduced.scatter_add_(dim=1, index=idx, src=out_grouped)
    
    # alternatively add the bias after the reduction
    #out_grouped_reduced = out_grouped_reduced + conv.bias[None, :, None, None].expand_as(out_grouped_reduced)
    
    # check error
    print(torch.allclose(out_grouped_reduced, out, atol=5e-6), (out_grouped_reduced - out).abs().max())

Thank you @ptrblck for your reply.
If I use the bias in this way to modify pre-trained CNNs such as AlexNet trained on ImageNet, should I then expect the same results for both models (i.e., the standard model and the modified model)?

Thank you

Dear @ptrblck
I tried your solution and it worked. The output of the torch.allclose function is now True.

If you don’t mind, I have one more question. In your code, you use the with torch.no_grad(): block when defining the new conv layer. Does this mean that the parameters of this new conv layer are not trainable?
When I remove with torch.no_grad(): I get an error. If I want to make it trainable, how should I implement it?

Thank you for your attention given to my issue.

I use the no_grad guard to copy the parameters from the plain conv to the grouped conv layer. The grouped conv layer will still be trainable, but the parameter copy should not be tracked by Autograd (as it’ll raise an error as you’ve already seen).
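
E.g. a small sketch (with arbitrary shapes) to verify this:

import torch
import torch.nn as nn

conv = nn.Conv2d(4, 8, 3, bias=False)
conv_grouped = nn.Conv2d(4, 32, 3, groups=4, bias=False)

with torch.no_grad():
    # the copy itself is not tracked by Autograd ...
    conv_grouped.weight.copy_(conv.weight.permute(1, 0, 2, 3).reshape(32, 1, 3, 3))

# ... but the parameter is still trainable
print(conv_grouped.weight.requires_grad)
# > True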