How to get the tensor of neurons for each hidden layer?

I’d like to visualize the distribution of the neurons associated with each hidden layer in my deep neural network. In principle, given the tensor, I can easily do this with TensorBoard.
For example, if I were interested in the distribution of the weights between layers, I could do something like this:

for n, w in self.model.named_parameters():
    self.writer.add_histogram('Weights connection {}'.format(n), w, global_step=t)

where n is the name of the parameter, w is the corresponding weight tensor, and t is the step at which the summary is saved.

Is there an equivalent way to iterate through the hidden layers and get the tensors that hold the neurons’ state of each layer?

You can add an if condition like if 'Linear' in n to filter specific layers.

The problem is that model.named_parameters() iterates over the weights of my network; for example (correct me if I misunderstood your point), Linear refers to a fully connected matrix of weights that maps a given hidden layer’s vector to the next one. I’m only interested in these vectors (not in the set of weights), i.e. the numbers that are multiplied by the weights during forward propagation.
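To make the distinction concrete, here is a minimal sketch with a hypothetical three-module model (`toy` and its layer sizes are assumptions for illustration only). It shows that named_parameters() yields only weights and biases, while the hidden vector exists only as the result of a forward pass on some input:

```python
import torch
import torch.nn as nn

# Hypothetical toy model; the layer sizes are arbitrary.
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# named_parameters() yields only learnable tensors (weights and biases).
param_shapes = {name: tuple(p.shape) for name, p in toy.named_parameters()}
print(param_shapes)

# The hidden-layer vector is computed from an input; it is not a parameter.
x = torch.randn(5, 4)        # batch of 5 inputs
hidden = toy[1](toy[0](x))   # output of the first Linear followed by ReLU
print(tuple(hidden.shape))
```

The parameter shapes are fixed by the architecture, whereas the first dimension of the hidden tensor follows the batch size of the input.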

You can directly retrieve the weight or bias tensors through the weight or bias attribute, like this:

import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 3, 1),
            nn.ReLU()
        )
        self.fc = nn.Linear(10, 10)


model = MyModel()
conv_weight = model.cnn[0].weight
fc_bias = model.fc.bias

Thank you for your comment!
I’m not interested in the weights; below is a simple example that shows my objective more clearly:

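import torch
import torch.nn as nn
import torch.nn.functional as F


# NetVariables is defined elsewhere (not shown); it presumably provides self.num_classes
class Net(nn.Module, NetVariables):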
    def __init__(self, params):
        
        self.params = params.copy()
        
        nn.Module.__init__(self)
        NetVariables.__init__(self, self.params)        
        
        # number of hidden nodes in each layer (32)
        hidden_1 = 32
        hidden_2 = 32

        # linear layer (784 -> hidden_1)
        if  (params['Dataset']=='MNIST'):
            self.fc1 = nn.Linear(28*28, hidden_1)
        # NOTE: MNIST images are grayscale with 28*28 pixels; CIFAR10 images are colour (3 channels) with 32*32 pixels
        elif  (params['Dataset']=='CIFAR10'):
            self.fc1 = nn.Linear(32*32*3, hidden_1)
        # linear layer (hidden_1 -> hidden_2)
        self.fc2 = nn.Linear(hidden_1,hidden_2)
        # linear layer (hidden_2 -> num_classes)
        self.fc3 = nn.Linear(hidden_2, self.num_classes)
        
        #weights initialization (this step can also be put below in a separate def)
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu')
        nn.init.xavier_normal_(self.fc3.weight)
        #initialize the bias to 0
        nn.init.constant_(self.fc3.bias, 0)
    # forward returns a dictionary with the output of every layer, not only the last one
    def forward(self,x):
        
        outs = {}
        # flatten image input
        if  (self.params['Dataset']=='MNIST'):
            x = x.view(-1,28*28)
        # NOTE: MNIST images are grayscale with 28*28 pixels; CIFAR10 images are colour (3 channels) with 32*32 pixels
        elif  (self.params['Dataset']=='CIFAR10'):
            x = x.view(-1,32*32*3)
        # add hidden layer, with relu activation function
        Fc1 = F.relu(self.fc1(x))
        
        outs['l1'] = Fc1
       
        Fc2 = torch.tanh(self.fc2(Fc1))
        
        
        outs['l2'] = Fc2
        
        # add output layer
        Out = self.fc3(Fc2)
        
        outs['out'] = Out
        
        return outs

In the above example the forward pass returns not only the output layer but a dictionary that also contains all the previous ones (the hidden layer outputs); this is all I want. However, with more complex networks the number of hidden layers increases, so I’d rather not do it manually with the dictionary approach above; I’d like to iterate over the hidden layers and get the corresponding output tensors (outs['l1'] and outs['l2'] in the above example).

I get your point. You can achieve this with register_forward_hook; here is some example code:

from functools import partial

import torch as t
from torchvision.models import resnet50

model = resnet50()
data = t.randn(1, 3, 224, 224)

layer_output = {}


def get_all_layers(net):
    def hook_fn(m, i, o, n=""):
        layer_output[n] = o

    # named_modules() already walks the module tree recursively,
    # so no manual recursion is needed
    for name, layer in net.named_modules():
        # hook only modules that own parameters directly (Conv2d, Linear, ...)
        if layer._parameters:
            layer.register_forward_hook(partial(hook_fn, n=name))


get_all_layers(model)
model(data)
for k, v in layer_output.items():
    print(f"the output shape of layer: {k} is {v.shape}")

You can refer to pytorch-hooks-gradient-clipping-debugging for a deeper explanation.
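Applied to the original question, the same hook idea might look roughly like this on a small fully connected network. This is only a sketch: `TinyNet`, `save_activation`, `activations`, and the layer sizes are hypothetical names chosen for illustration, and the activations are written as modules (nn.ReLU, nn.Tanh) so that hooks can capture the post-activation hidden vectors.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical MLP; activations are modules so they can be hooked."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 32)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(32, 32)
        self.act2 = nn.Tanh()
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten the image input
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        return self.fc3(x)


activations = {}


def save_activation(name):
    # Build a hook that stores the module's output under the given name.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


model = TinyNet()
# Register a hook on every activation module and keep the handles.
handles = [
    module.register_forward_hook(save_activation(name))
    for name, module in model.named_modules()
    if isinstance(module, (nn.ReLU, nn.Tanh))
]

model(torch.randn(8, 1, 28, 28))
for name, act in activations.items():
    print(name, tuple(act.shape))
    # With a SummaryWriter as in the question, each tensor could be logged, e.g.:
    # writer.add_histogram('Activations {}'.format(name), act, global_step=t)

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()
```

If the activation functions are applied through functional calls (such as F.relu in the Net above), there is no module to hook for them; hooks on fc1 and fc2 would then capture the pre-activation values instead.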
