Extraction of activations from a layer

(Muammar El Khatib) #1

I have a model and would like to regularize the activations of one of its layers. Do I have to do something “special” to get them, or would simply returning that layer's output from forward() work?

I would appreciate it if anyone could give me a hint. I am probably overthinking this.

#2

Doesn’t one of the torch.nn batch-norm layers (BatchNorm1d, BatchNorm2d, BatchNorm3d) fit your request? Otherwise, you can always split your forward method into several smaller methods to get the intermediate results, and normalize them as you wish. The most complex case is a module that you didn’t implement yourself. In that case, you can subclass the module and override its forward method. For instance:

import torch
from torchvision.models.resnet import ResNet, BasicBlock, resnet18
class NewResNet(ResNet):
    def __init__(self, **kw):
        super(NewResNet, self).__init__(**kw)
    def forward_1(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        return x
    def forward(self, x):
        x = self.forward_1(x)
        # Potentially a new normalization could happen here ...
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# A little check to validate that we didn't break the module
torch.manual_seed(2019)
net = NewResNet(block=BasicBlock, layers=[2, 2, 2, 2])
torch.manual_seed(2019)
# Note: on torchvision >= 0.13, `pretrained=False` is deprecated in favor of `weights=None`
net_check = resnet18(pretrained=False)

img = torch.rand((1, 3, 224, 224))
res1 = net_check(img).sum()
res2 = net(img).sum()

print(res1, res2) # Same results
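
To come back to the original question: returning the intermediate activations from forward() alongside the output should also work, since they stay attached to the autograd graph. Below is a minimal sketch of that idea, not a definitive recipe. TinyNet, its layer sizes, the L1 penalty, and l1_weight are all assumptions made up for illustration. The last part shows a forward hook as an alternative for modules whose forward() you would rather not touch.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))  # activations we want to regularize
        return self.fc2(hidden), hidden   # return them alongside the output


torch.manual_seed(0)
net = TinyNet()
x = torch.rand(4, 10)
target = torch.randint(0, 2, (4,))

# Option 1: use the activations returned by forward() in the loss.
out, hidden = net(x)
l1_weight = 1e-3                   # assumed penalty strength
penalty = hidden.abs().mean()      # L1 activation penalty (one possible choice)
loss = nn.functional.cross_entropy(out, target) + l1_weight * penalty
loss.backward()                    # gradients also flow through the penalty

# Option 2: capture a submodule's output with a forward hook,
# without changing the model's forward().
captured = {}

def save_output(module, inputs, output):
    captured["fc1"] = output       # output of fc1, before the ReLU

handle = net.fc1.register_forward_hook(save_output)
out2, hidden2 = net(x)
handle.remove()                    # remove the hook once it is no longer needed
```

With the hook, captured["fc1"] holds the output of fc1 before the ReLU, and it can be added to the loss in the same way as hidden above.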