Unable to visualize feature maps

Here is my conv model, and I’m trying to visualize feature maps during training (in the train() function) based on the answers for another question. Getting the error TypeError: Invalid shape (7,) for image data. Could you please guide me through what I am missing?

class CKNet(nn.Module):
    def __init__(self):
        super(CKNet, self).__init__()

        kernel_size = 3

        self.activation_fn = nn.ReLU(inplace=True)

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(16),
            # nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(16),
            # nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(32),
            # nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv5 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(64),
            # nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv7 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(128),
            # nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        self.conv8 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=kernel_size),
            self.activation_fn,
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d()
        )

        last_width_height = 4
        conv_out_channels = self.conv8._modules['0'].out_channels
        input_features = last_width_height * last_width_height * conv_out_channels

        self.fc1 = nn.Sequential(
            nn.Linear(input_features, 768),
            self.activation_fn,
            nn.Dropout2d())
        self.fc2 = nn.Sequential(
            nn.Linear(768, 128),
            self.activation_fn,
            nn.Dropout2d())
        self.fc3 = nn.Sequential(
            nn.Linear(128, len(folders)))

    def forward(self, x):
        x = x.to(device)
        return x


def train(epoch):
    epoch_start = time()
    model.train()
    train_total_loss = 0
    index = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output = model(images)  # forward
        output = output.to(device)

        loss = criterion(output, labels)
        # backward
        loss.backward()
        # update weights
        optimizer.step()

        train_total_loss += loss.item()
        train_counter.append((batch_idx * train_batch_size) + (epoch * len(train_loader.dataset)))

        index += 1

        # print statistics
        if batch_idx % log_interval == 0:
            # print(f'batch_idx: {batch_idx}')
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch + 1, (batch_idx + 1) * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            train_losses.append(train_total_loss / index)
            print(f'Epoch #{epoch + 1} duration: {time() - epoch_start} seconds')

            # Visualize feature maps
            activation = {}

            def get_activation(name):
                def hook(model, input, output):
                    activation[name] = output.detach()

                return hook

            activation['conv1'] = output.detach()
            model.conv1[0].register_forward_hook(get_activation('conv1'))
            act = activation['conv1'].squeeze().cpu()
            num_plot = 4
            fig, axarr = plt.subplots(min(act.size(0), num_plot))
            for idx in range(min(act.size(0), num_plot)):
                axarr[idx].imshow(act[idx])
            plt.xticks([])
            plt.yticks([])
            plt.show()

def main():
    global train_loader, validation_loader

    global face_train_dataset
    face_train_dataset = datasets.ImageFolder(root=DEST_PATH_TRAIN, transform=data_transforms)

    train_loader = DataLoader(face_train_dataset,
                              batch_size=train_batch_size, shuffle=True,
                              num_workers=4)

    global prediction_counter
    prediction_counter = Counter()

    face_validation_dataset = datasets.ImageFolder(root=DEST_PATH_VALIDATION,
                                                   transform=data_transforms)
    validation_loader = DataLoader(face_validation_dataset,
                                   batch_size=test_batch_size, shuffle=False,
                                   num_workers=4)

    print(f'# of training images: {len(train_loader.dataset)}')
    print(f'# of validation images: {len(validation_loader.dataset)}')

    global model
    model = CKNet().to(device)
    print(f'Model Overview:\n{model}')

    global criterion
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    global optimizer
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    global train_losses, train_counter, validation_losses, validation_counter
    train_losses = []
    train_counter = []
    validation_losses = []
    validation_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]

    global accuracy
    accuracy = 0  # initial value
    best_accuracy = 0
    start_time = time()
    # test()
    global last_epoch
    last_epoch = 0
    for epoch in range(n_epochs):
        train(epoch)
        validation_loss, accuracy = test()
        last_epoch = epoch

1 Like

Very interested on this!

@talhak Firstly a forward hook should be attached to a module before its forward call. In the above code, you first do output = model(images) # forward and then model.conv1[0].register_forward_hook(get_activation('conv1')). Even though this doesn’t call hook() during the first forward call to conv1, subsequent forward calls will invoke hook. Also you register a hook for every batch_idx % log_interval which is not required.

Coming to visualization of activation maps, can you let know the shape of the conv1 output?
Problem might occur because the conv outputs are in the order C, H, W whereas the image plot expects H, W, C(Note: if C=1, then it shouldn’t cause issues since you do squeeze anyway)

If I call it before the forward call, what would be the output? @mailcorahul

And if I hook the visualization registration just before the forward call (as I posted below), getting TypeError: Invalid shape (16, 62, 62) for image data. @mailcorahul

# Visualize feature maps
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()

    return hook
model.conv1[0].register_forward_hook(get_activation('conv1'))
# forward + backward + optimize
output = model(images)  # forward
output = output.to(device)

# activation['conv1'] = output.detach()

act = activation['conv1'].squeeze().cpu()
num_plot = 4
fig, axarr = plt.subplots(min(act.size(0), num_plot))
for idx in range(min(act.size(0), num_plot)):
    axarr[idx].imshow(act[idx])
plt.xticks([])
plt.yticks([])
plt.show()

The shape (16, 62, 62) I suppose should be (channels, height, width). Iterate over all those 16 feature maps(each of size (62, 62)) and plot them individually.
As I mentioned above, imshow expects (H, W, C) and here C=62 which is why it is throwing error.

1 Like