How to visualize CNN filters?

I want to visualize the filters from model.conv1 after the final epoch. Thanks. My CNN looks like this:

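import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 40, 5, 1)   # 1 input channel, 40 filters of size 5x5
        self.conv2 = nn.Conv2d(40, 80, 5, 1)  # 40 input channels, 80 filters of size 5x5
        self.fc1 = nn.Linear(4480, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # NOTE: this flattens to 4*4*80 = 1280 features, while fc1 expects 4480;
        # one of the two has to change to match the actual input size.
        x = x.view(-1, 4*4*80)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)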

You can directly access the filters via:

filters = model.conv1.weight

and then visualize them with e.g. matplotlib.pyplot.imshow.
Note that filters can have several input channels, so you might need to plot each channel in a separate subplot.
In your model, self.conv1 has one input channel, so you should be able to visualize the filters next to each other.
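As a rough sketch of the above, the 40 filters could be drawn in a single subplot grid. The layer below is a randomly initialized stand-in for the trained model.conv1, and the helper name plot_filters, the grid width, and the output file name are all assumptions made for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import torch.nn as nn

# Stand-in for the trained model.conv1 from the question (randomly initialized here).
conv1 = nn.Conv2d(1, 40, 5, 1)

# Weight shape is (out_channels, in_channels, kH, kW) = (40, 1, 5, 5).
# detach() drops the autograd graph; cpu() makes the tensor usable by matplotlib.
filters = conv1.weight.detach().cpu()

def plot_filters(filters, ncols=8, path="conv1_filters.png"):
    """Draw each single-channel filter in its own subplot and save the grid."""
    n = filters.shape[0]
    nrows = (n + ncols - 1) // ncols  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(ncols * 1.5, nrows * 1.5))
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < n:
            # filters[i, 0] is the 5x5 kernel of output channel i
            ax.imshow(filters[i, 0].numpy(), cmap="gray")
            ax.set_title(str(i), fontsize=8)
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return nrows, ncols

grid_shape = plot_filters(filters)
```

With the real model, replacing conv1 by model.conv1 should be the only change needed. For a layer with more than one input channel (such as conv2), filters[i, 0] would only show the first channel of each filter.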


Thank you! It works. I tried this:

filters = model.conv1.weight
for i in range(40):

And I got 40 PNGs, but the images are too small. Is there any way to make them bigger? Thanks.

You could upsample the filters first, e.g. via F.interpolate(..., mode='nearest').
With mode='nearest', each value is repeated rather than blended with its neighbors, as e.g. linear interpolation would do.
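A minimal sketch of that upsampling step, assuming a randomly initialized stand-in for model.conv1 and an arbitrary scale factor of 20 (neither comes from the thread):

```python
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the trained model.conv1 (randomly initialized here).
conv1 = nn.Conv2d(1, 40, 5, 1)
filters = conv1.weight.detach().cpu()  # shape (40, 1, 5, 5)

# F.interpolate expects (N, C, H, W), which the conv weight layout already matches.
# mode="nearest" copies each weight into a 20x20 block instead of smoothing.
upscaled = F.interpolate(filters, scale_factor=20, mode="nearest")  # (40, 1, 100, 100)
```

Each upscaled[i, 0] can then be passed to imshow or saved as before; the individual filter values stay visible as sharp squares.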

Alternatively, you could save the filters with matplotlib or PIL, both of which let you set the output size (for example, the figure size and DPI in matplotlib, or Image.resize in PIL).
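One way the PIL route might look, as a sketch only: the helper name filter_to_image, the min-max scaling to 0..255, the 100x100 target size, and the file name are assumptions for illustration, and conv1 is again a random stand-in for the trained layer.

```python
import numpy as np
import torch.nn as nn
from PIL import Image

# Stand-in for the trained model.conv1 (randomly initialized here).
conv1 = nn.Conv2d(1, 40, 5, 1)
filters = conv1.weight.detach().cpu().numpy()  # shape (40, 1, 5, 5)

def filter_to_image(kernel, size=(100, 100)):
    """Scale one 2D kernel to 0..255 and resize it without smoothing."""
    lo, hi = kernel.min(), kernel.max()
    scaled = (kernel - lo) / (hi - lo + 1e-8)  # min-max normalize to [0, 1)
    img = Image.fromarray((scaled * 255).astype(np.uint8))  # 8-bit grayscale
    # NEAREST keeps the hard pixel edges of the 5x5 kernel
    return img.resize(size, resample=Image.NEAREST)

img = filter_to_image(filters[0, 0])
img.save("filter_00.png")
```

Note that normalizing each filter separately makes the images easy to read but hides differences in magnitude between filters; a shared min and max across all 40 would preserve them.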

You can see a complete implementation here


