Forward hook activations for loss computation

@Alexander_Riedel A very interesting use case. I would like to know why you want the same feature map from two successive conv layers.

You will have to make a few tweaks to the code:

  • L1 Loss compares its inputs element-wise, so both feature maps must have the same shape. In particular, the number of channels must be the same.

  • Resize the smaller feature map to the larger one's spatial size (height, width) before passing both to the L1 Loss. You can use torchvision.transforms.functional.resize for this.

  • In your hook function, remove detach(): it returns a tensor that is cut off from the autograd graph, so gradients can no longer flow back through it during backpropagation.
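A minimal sketch of the last two points, using a single toy conv layer (the layer and the `captured` dict are illustrative assumptions, not part of the original code). It shows that a hook storing `output.detach()` yields a tensor with `requires_grad=False`, while storing the raw output keeps it in the graph. For the resize step it uses `torch.nn.functional.interpolate` (bilinear) instead of torchvision's `resize`, only to keep the sketch dependent on torch alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy layer used only to illustrate the effect of detach() inside a hook.
conv = nn.Conv2d(3, 6, 5)
captured = {}

def keep_graph_hook(module, inp, output):
    # Store the raw output: it stays attached to the autograd graph.
    captured["attached"] = output

def detach_hook(module, inp, output):
    # Store a detached copy: gradients can no longer flow back through it.
    captured["detached"] = output.detach()

h1 = conv.register_forward_hook(keep_graph_hook)
h2 = conv.register_forward_hook(detach_hook)
conv(torch.rand(2, 3, 28, 20))  # output shape: (2, 6, 24, 16)
h1.remove()
h2.remove()

print(captured["attached"].requires_grad)  # True
print(captured["detached"].requires_grad)  # False

# Spatial sizes must match before an L1 loss; bilinear upsampling is one option.
small = torch.rand(2, 6, 8, 4)
resized = F.interpolate(small, size=captured["attached"].shape[-2:],
                        mode="bilinear", align_corners=False)
print(resized.shape)  # torch.Size([2, 6, 24, 16])
```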

Following is sample code. I changed the layer sizes a bit to make it work: conv2 outputs the same 6 channels as conv1, and fc1 matches the flattened size.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as F1  # provides resize() for tensors

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 6, 5)
        self.fc1 = nn.Linear(48, 128)
        self.fc2 = nn.Linear(128, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch: 6 * 4 * 2 = 48 features for a 28x20 input
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

activation = {}
def get_activation(name):
    def hook(module, inp, output):
        # Keep the raw output (no detach()) so the loss can backpropagate through it
        activation[name] = output
    return hook

model = Net()

model.conv1.register_forward_hook(get_activation("conv1"))
model.conv2.register_forward_hook(get_activation("conv2"))

x = torch.rand(6, 3, 28, 20)

model(x)
print(activation['conv1'].size())
print(activation['conv2'].size())
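l1_loss = torch.nn.L1Loss()
# conv1 output: (6, 6, 24, 16); conv2 output: (6, 6, 8, 4). Upsample conv2 to 24x16 so the shapes match.
loss = l1_loss(activation["conv1"], F1.resize(activation["conv2"], size=(24, 16)))

loss.backward()  # L1Loss uses mean reduction by default, so loss is already a scalar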
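In practice the feature-map loss is usually added to the main task loss inside a training step. The sketch below assumes a classification objective with random labels, a hypothetical weighting factor `aux_weight`, and plain SGD; none of these come from the original question. It again swaps torchvision's `resize` for `F.interpolate` so it runs with torch only, and it confirms that gradients from the combined loss reach both conv layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Same layer sizes as the sample above.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 6, 5)
        self.fc1 = nn.Linear(48, 128)
        self.fc2 = nn.Linear(128, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

torch.manual_seed(0)
model = Net()
activation = {}

def get_activation(name):
    def hook(module, inp, output):
        activation[name] = output  # no detach(): keep it in the graph
    return hook

handles = [
    model.conv1.register_forward_hook(get_activation("conv1")),
    model.conv2.register_forward_hook(get_activation("conv2")),
]

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
aux_weight = 0.1  # hypothetical weighting of the feature-map loss

x = torch.rand(6, 3, 28, 20)
targets = torch.randint(0, 10, (6,))  # random labels, for illustration only

optimizer.zero_grad()
logits = model(x)  # the hooks fill `activation` during this call
task_loss = F.cross_entropy(logits, targets)
conv2_up = F.interpolate(activation["conv2"], size=activation["conv1"].shape[-2:],
                         mode="bilinear", align_corners=False)
feat_loss = F.l1_loss(activation["conv1"], conv2_up)
loss = task_loss + aux_weight * feat_loss
loss.backward()
optimizer.step()

# Both conv layers received gradients from the combined loss.
print(model.conv1.weight.grad is not None, model.conv2.weight.grad is not None)  # True True

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

Removing the hooks at the end avoids silently overwriting `activation` on later forward passes, for example during evaluation.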