Get the activations of the second to last layer

Hello! I have this piece of code:

import torch
import torch.nn as nn

class View(nn.Module):
    """Reshape a tensor inside an nn.Sequential."""
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        N = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, N, 4, 2, 1),      # 64x64 -> 32x32
            nn.ReLU(True),
            nn.Conv2d(N, N, 4, 2, 1),      # 32x32 -> 16x16
            nn.ReLU(True),
            nn.Conv2d(N, 2*N, 4, 2, 1),    # 16x16 -> 8x8
            nn.ReLU(True),
            nn.Conv2d(2*N, 2*N, 4, 2, 1),  # 8x8 -> 4x4
            nn.ReLU(True),
            nn.Conv2d(2*N, 8*N, 4, 1),     # 4x4 -> 1x1
            nn.ReLU(True),
            View((-1, 8*N*1*1)),           # flatten to (batch, 8*N)
            nn.Linear(8*N, 2),
        )
            
    def forward(self, x):
        z = self.encoder(x)
        return z

I take 64x64 images and return 2 numbers. I trained the network to do what I need and it works well, but now (for another reason) I would like to get the activations just before the output, i.e. the result of that flattening layer. In other words, I need the values of the 8*N-dimensional vector before the last matrix multiplication. How can I do this? Thank you!
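For reference, a quick shape check with a dummy input (torch.randn standing in for a real image batch) shows what goes in and comes out:

net = Encoder()
x = torch.randn(1, 3, 64, 64)  # dummy batch: one 64x64 RGB image
z = net(x)
print(z.shape)                 # torch.Size([1, 2])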

Hi @smu226,

You should be able to do it with hooks:

view_output = None  # will hold the View layer's output after each forward pass

def hook_fn(module, input, output):
    global view_output
    view_output = output

net = Encoder()
# net.encoder[-2] is the View layer, i.e. the module just before the final Linear
hook = net.encoder[-2].register_forward_hook(hook_fn)
# call hook.remove() to remove the hook later

so after each forward pass you will find View’s output in view_output.
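One caveat (an assumption about your use case, not something stated above): the tensor captured by the hook is still attached to the autograd graph. If you only need the values, you can detach it inside the hook:

def hook_fn(module, input, output):
    global view_output
    view_output = output.detach()  # drop the autograd history if you only need the values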

Hope that helps!


Thank you so much for this, it looks like what I need. But how do I actually use it? I tried:

net(x)
print(view_output)

but I still get None.

Oops, please look at my edit; it should work better now 🙂

Did you get the 8*N-dimensional vector successfully?
Did you need to call hook.remove() to complete it?

I just ran the code with a batch size of 1 and N=32:

>>> view_output.shape
torch.Size([1, 256])

so it looks good.

hook.remove() is for when you no longer want the hook to run.
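Putting everything together, a minimal end-to-end sketch (the torch.randn input is just a stand-in for real data):

import torch

view_output = None

def hook_fn(module, input, output):
    global view_output
    view_output = output

net = Encoder()
hook = net.encoder[-2].register_forward_hook(hook_fn)  # encoder[-2] is the View layer

x = torch.randn(1, 3, 64, 64)  # dummy batch: one 64x64 RGB image
z = net(x)                     # the forward pass triggers the hook

print(view_output.shape)       # torch.Size([1, 256]), i.e. the 8*N-dimensional vector
hook.remove()                  # detach the hook once you no longer need it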