How to add multiple return values in Bottleneck.forward in ResNet?

I want to add another return value in Bottleneck.forward. It's easy to add one, but how can I get this value outside the net?

For example:

import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        x = self.shortcut(x)
        out += x
        out = F.relu(out)
        return out, x

Then where do I use the returned value x in the ResNet class? Can anyone show me the magic?

You would probably have to change the code a bit and skip the creation of the nn.Sequential modules (line of code).
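For example, something along these lines could work (a rough sketch only; BottleneckStack is just an illustrative name for what a _make_layer-style helper would otherwise return as an nn.Sequential, and it assumes your modified Bottleneck.forward that returns a tuple):

class BottleneckStack(nn.Module):
    def __init__(self, blocks):
        super(BottleneckStack, self).__init__()
        # nn.ModuleList instead of nn.Sequential, so forward can unpack tuples
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        extras = []                       # collect the shortcut outputs
        out = x
        for block in self.blocks:
            out, shortcut_out = block(out)
            extras.append(shortcut_out)
        return out, extras                # pass the extra tensors up to ResNet.forward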

Do you need x for the training, i.e. are you feeding it to another layer, or is it just for debugging purposes?
If the latter is true, you could use hooks to get the activation.

Thanks for your reply!
Yes, I need x for training, so I've considered hooks, but they don't fit my purpose.
Could you elaborate more on the modification?

Sure. Let’s first see if we can scale down the problem.
Would it be sufficient if we just use Bottleneck modules, or do you need the flexibility of BasicBlock/Bottleneck?
Also, do you need the different ResNet architectures, i.e. resnet18, resnet34, ... or can we focus on just one implementation?

Thanks.
Bottleneck + ResNet-101, they're enough for me.

OK, and what would you like to do with x? Should the Bottleneck layers now be fed with two inputs?

OK.
I want to compute a regularization term using x or out, i.e. torch.norm(x) or torch.norm(out), sum these terms over all Bottleneck modules, and finally add the result to the loss as a regularizer.

I think this should be easier with hooks. Why do you think it doesn’t fit your purpose?

import torch

activation = {}

def get_activation(name):
    # store the output of the hooked module under `name`
    def hook(model, input, output):
        activation[name] = output
    return hook


model = Bottleneck(3, 3)
model.shortcut.register_forward_hook(get_activation('bottleneck1'))
x = torch.randn(1, 3, 224, 224)
output = model(x)
print(activation['bottleneck1'])

loss = torch.norm(activation['bottleneck1']) 
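For ResNet-101 you could then register one hook per Bottleneck and sum the norms afterwards. A minimal sketch, assuming model is a ResNet assembled from the Bottleneck class above (the variable names are just illustrative) and get_activation is defined as before:

# register one hook on every Bottleneck's shortcut
for name, module in model.named_modules():
    if isinstance(module, Bottleneck):
        module.shortcut.register_forward_hook(get_activation(name))

output = model(torch.randn(1, 3, 224, 224))

# sum the norms of all captured activations; add this term (scaled by a
# weight of your choice) to the training loss as the regularizer
reg = sum(torch.norm(act) for act in activation.values())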

Cool. Just registering the hook on shortcut does help. So if I want to pass out any feature map inside the Bottleneck, I just register a hook after the specific layer, am I right?
i.e. register_forward_hook on Bottleneck.bn3 can give me the final feature map of the Bottleneck?
But one more question: what if I want out right after computing out += x, i.e. before the final ReLU?

You could create a dummy layer and register the hook to it:


class Identity(nn.Module):
    """Pass-through layer, used only as a target for a forward hook."""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
    

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
        
        self.identity = Identity()  # dummy layer: gives a hook access to the pre-ReLU sum

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        x = self.shortcut(x)
        out += x
        out = self.identity(out)  # hook target: out right after the residual addition
        out = F.relu(out)
        return out

model = Bottleneck(3, 3)
model.identity.register_forward_hook(get_activation('identity'))
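After a forward pass, activation['identity'] then holds out right after the residual addition (before the ReLU), e.g. continuing the example above:

out = model(torch.randn(1, 3, 224, 224))
reg = torch.norm(activation['identity'])   # norm of the pre-ReLU sum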

Sorry for my late reply.
Well, I appreciate your solution and your imagination very much. :smile:

ptrblck, I've run into a troublesome bug.
Since the feature maps of each layer are distributed across different GPUs when using DataParallel, how can I add them up (or apply other manipulations)? I get this error:
RuntimeError: arguments are located on different GPUs
I've searched for this problem but couldn't find a solution.

You could push them to the default GPU and add them together:

loss = torch.norm(activation['bottleneck'].to('cuda:0'))

This assumes you would like to sum them on GPU0.
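With several hooked activations the same idea in a loop could look like this (a sketch, assuming activation holds tensors that may live on different GPUs):

# gather every hooked activation on GPU0 before reducing it
reg = torch.zeros((), device='cuda:0')
for act in activation.values():
    reg = reg + torch.norm(act.to('cuda:0'))
# reg now lives on GPU0 and can be added to the loss there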

If this doesn’t work, could you share a small code sample, so that I can have a look?