<< extracting stats from activated layers - training >>

Hey!
I would like to make sure I am doing this right and make some improvements.
It’s kinda lengthy, but I want to be as specific as possible.

The train and test code is in this repo; I am only making some minor changes, nothing major, just adding to these examples.

>>> import torch
>>> torch.__version__
'1.0.0'

I am running experiments on CIFAR-10 using various architectures; for this example, let’s use VGG11.
I would like to record stats of all activated layers during training, either every epoch or every n-th epoch.

Result of the print statement below:

print(list(net.named_modules()))
[('', VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU(inplace)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU(inplace)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (27): ReLU(inplace)
    (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (29): AvgPool2d(kernel_size=1, stride=1, padding=0)
  )
  (classifier): Linear(in_features=512, out_features=10, bias=True)
)), ('features', Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace)
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU(inplace)
  (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (13): ReLU(inplace)
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (17): ReLU(inplace)
  (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (20): ReLU(inplace)
  (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (24): ReLU(inplace)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (27): ReLU(inplace)
  (28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (29): AvgPool2d(kernel_size=1, stride=1, padding=0)
)), 
('features.0', Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.1', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.2', ReLU(inplace)), 
('features.3', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)), 
('features.4', Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.5', BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.6', ReLU(inplace)), 
('features.7', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)), 
('features.8', Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.9', BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.10', ReLU(inplace)), 
('features.11', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.12', BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.13', ReLU(inplace)), 
('features.14', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)), 
('features.15', Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.16', BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.17', ReLU(inplace)), 
('features.18', Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.19', BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.20', ReLU(inplace)), 
('features.21', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)), 
('features.22', Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.23', BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.24', ReLU(inplace)), 
('features.25', Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), 
('features.26', BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), 
('features.27', ReLU(inplace)), 
('features.28', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)), 
('features.29', AvgPool2d(kernel_size=1, stride=1, padding=0)), 
('classifier', Linear(in_features=512, out_features=10, bias=True))]

Using the hooks you mentioned here, I am able to extract stats from the layers like this:

from collections import defaultdict
import numpy as np

activations = defaultdict(list)

def save_activation(name):
    def hook(model, input, output):
        # store (epoch, min, max, mean, std) of the layer output for every batch
        activations[name].append((epoch,
                                  np.float(output.min().detach().data),
                                  np.float(output.max().detach().data),
                                  np.float(output.mean().detach().data),
                                  np.float(output.std().detach().data)))
    return hook

net.module.register_forward_hook(save_activation('features.2'))
net.module.register_forward_hook(save_activation('features.6'))
net.module.register_forward_hook(save_activation('features.10'))
net.module.register_forward_hook(save_activation('features.13'))

Here is a litany of my questions :slight_smile::

  1. Given the output of net.named_modules, am I getting the right indices for the activated layers, e.g. (2): ReLU(inplace)?
    This is how I am accessing the hooked information:
    net.module.register_forward_hook(save_activation('features.2'))
    If this is not right, please advise on how I can access the activated layers explicitly and efficiently.

  2. The dictionary collects stats for all batches - 391 to be exact - which is kind of spammy. How could I access the stats of the activated layers at the end of each epoch?

  3. Is defaultdict the right data structure for collecting all that information?

I’ve seen that fast.ai did some work on this, but I don’t really understand how to use it, so I would prefer to stick with this version of the training code.
The best analogy that comes to mind is tf.summary in TensorFlow, which collects info to visualise in TensorBoard. To be clear, I want to do more than visualisation, hence the need for the raw stats.

Looking forward to hearing your thoughts on this.

  1. You would have to call register_forward_hook on the particular layer, not the complete model/module. Based on your code it looks like you are working with a DataParallel model, since you are accessing the .module attribute. If that’s the case, you would need to address each layer separately:
net.module.features[2].register_forward_hook(save_activation('features.2'))
net.module.features[6].register_forward_hook(save_activation('features.6'))
net.module.features[10].register_forward_hook(save_activation('features.10'))
net.module.features[13].register_forward_hook(save_activation('features.13'))

The name passed to save_activation is just for your own convenience. :wink:
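
If you want to hook all ReLU layers without hard-coding the indices, a possible shortcut (just a sketch, reusing your save_activation and the DataParallel net from above) is to iterate over the named modules and register a hook on every nn.ReLU instance:

import torch.nn as nn

# register the stats hook on every ReLU module found inside the model
for name, module in net.module.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(save_activation(name))

This way the keys in the activations dict will automatically match the layer names from the printout above (e.g. features.2).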

  2. You could register the hook only for the last epoch, pass a flag to save_activation indicating whether you are currently in the last epoch, or just overwrite the activations using:
def save_activation(name):
    def hook(model, input, output):
        activations[name] = [(epoch,
                              np.float(output.min().detach().data),
                              np.float(output.max().detach().data),
                              np.float(output.mean().detach().data),
                              np.float(output.std().detach().data))]
    return hook
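
If you do want a summary for every epoch (or every n-th epoch) rather than only the last one, another option - just a sketch on top of the defaultdict above, not part of your original code - is to keep the per-batch records and reduce them at the end of the epoch. Note that averaging the per-batch min/max/std only approximates the true epoch-level statistics:

import numpy as np

def summarize_epoch(activations, epoch):
    # reduce the per-batch (epoch, min, max, mean, std) records
    # into a single (min, max, mean, std) row per layer for this epoch
    summary = {}
    for name, records in activations.items():
        batch_stats = np.array([r[1:] for r in records if r[0] == epoch])
        summary[name] = batch_stats.mean(axis=0)
    return summary

# e.g. at the end of every n-th epoch:
# if epoch % n == 0:
#     print(summarize_epoch(activations, epoch))
#     activations.clear()  # drop the per-batch records to avoid the spam
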
  3. It depends on your use case. I personally like to use dicts, as you can be sure which activation you are currently working with. You could use whatever data structure fits your needs, but I’m quite paranoid about plain lists, as I cannot be sure I’m looking at the right layer (e.g. what if the list was somehow extended?). :wink:

If you want to visualize some stats of your model using TensorBoard, have a look at this Ignite example, where some handlers are used to do this.


Awesome, thanks for your help with this.

Just one more stupid question:
Any layer described like the one below refers to an activation layer, and the number is the index at which it can be accessed, right?

(2): ReLU(inplace)

In the above example:

net.module.register_forward_hook(save_activation('features.2'))

Do I understand that correctly? :slight_smile:

I know it’s a stupid question, but I really want to be sure :stuck_out_tongue:

Thanks once more @ptrblck

The (2) corresponds to the layer name.
If you just pass the layers into an nn.Sequential module, the name will correspond to the index in this sequential block.
However, you could also pass custom names:

model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
    
print(model)
> Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
print(id(model.conv1) == id(model[0]))
> True

In your example code, you registered the hook on the complete model, so input will correspond to the input data and output to the result of your last layer.
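
To make that concrete, here is a minimal sketch (with a toy model, not your VGG) showing that a forward hook registered on the whole module only sees the model input and the final output:

import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

def whole_model_hook(module, input, output):
    # input is a tuple holding the data passed to forward,
    # output is the result of the last layer, not an intermediate activation
    print(input[0].shape, output.shape)

toy.register_forward_hook(whole_model_hook)
toy(torch.randn(3, 4))
> torch.Size([3, 4]) torch.Size([3, 2])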

Apologies, my bad, I pasted the wrong thing :slight_smile:
I meant to paste it your way:

net.module.features[2].register_forward_hook(save_activation('features.2'))

Like that, I am accessing the tensor of this particular activated layer, right? So I can compute stats.

You are indexing into the nn.Sequential module, which is defined as features.
print(net.module.features[2]) will give you that particular layer, not the activation tensor.
The hook will be registered on this layer and will receive the activation tensor in def hook.

And that hooked tensor will correspond to the specified layer (accessed by indexing or naming)?

So, by printing the DataParallel model as above with list(net.named_modules()), I will know the indices of all layers, including the activations, which will be named with the following pattern (in this example; they could be named differently elsewhere):

(2): ReLU(inplace)
(6): ReLU(inplace)
...
(27): ReLU(inplace)

Then, by registering the hooks with the right indices using your method, I can compute stats for the activated layers.
Apologies for the layman lingo, but I am not as technical as you are ;).

And big thanks for being patient with my questions :slight_smile:

Yes, the tensor in your hook function will correspond to the specified layer (accessed by indexing or naming).

Yes, if the activations are created as modules. The alternative is to use the functional API for the activation functions, e.g. as done in DenseNet.
If you encounter such a model, you might want to override the forward method and store the activations manually, or replace the functional calls with the corresponding modules.
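
As a rough sketch of that second option (using made-up module names, not DenseNet itself): a functional activation leaves nothing to hook into, so you would rewrite it as a submodule:

import torch.nn as nn
import torch.nn.functional as F

# functional API: F.relu is not a module, so no hook can be registered on it
class FunctionalNet(nn.Module):
    def __init__(self):
        super(FunctionalNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x):
        return F.relu(self.conv(x))

# rewritten with nn.ReLU as a submodule, so a forward hook can be attached
class ModularNet(nn.Module):
    def __init__(self):
        super(ModularNet, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# net = ModularNet()
# net.relu.register_forward_hook(save_activation('relu'))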

Yes, your code looks fine to achieve this.

Haha, no need to apologize. Your posts are perfectly clear! :wink:


Thank you, good human, for all the clarifications! :wink:
Now time to get to work :smiley: