Can a trained model store layer outputs in PyTorch?

Hi there,

In Torch, when we train a model on an image, the model can store the outputs internally. Is it also possible to store the outputs in a PyTorch model after training on an image?

In Torch we can access the output after the forward pass like in the following example …

out = model:forward(input)

Suppose there are 10 modules in the model and I want to access the output of the 5th layer. In that case:

layer_5 = model.modules[5].output

Is something like this possible in PyTorch too? If not, how would I do the same in PyTorch?

Thanks

If you would like to store the output of a certain (internal) layer, you could use forward hooks as described here.
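As a minimal sketch of that approach (the model, layer index, and shapes here are made up for illustration; `nn.Sequential` supports integer indexing, so `model[4]` corresponds to the 5th layer from the original question):

```python
import torch
import torch.nn as nn

# Toy model standing in for the 10-module network from the question
model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(10)])

stored = {}

def hook(module, inp, out):
    # Detach so the stored tensor doesn't keep the autograd graph alive
    stored["layer_5"] = out.detach()

# Register the hook on the 5th layer (index 4)
handle = model[4].register_forward_hook(hook)

out = model(torch.randn(1, 4))
print(stored["layer_5"].shape)  # torch.Size([1, 4])

handle.remove()  # remove the hook once it is no longer needed
```

The handle returned by `register_forward_hook` lets you remove the hook later, so you don't keep paying the storage cost on every forward pass.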

Hi @ptrblck, how can I store each layer's (internal or external) output as it is calculated during the forward pass? I am not using the name of each layer; I want to store the outputs according to the index of each layer.

In my model there are also layers nested inside other layers.

Example:

Sequential(
  (0): Conv2d()
  (1): BatchNorm2d()
  (2): ReLU()
  (3): Sequential(
    (0): ConcatTable(
      (0): Sequential(
        (0): BatchNorm2d()
        (1): ReLU()
        (2): Conv2d()
        (3): BatchNorm2d()
        (4): ReLU()
        (5): Conv2d()
        (6): BatchNorm2d()
        (7): ReLU()
        (8): Conv2d()
      )
      (1): Sequential()
    )
    (1): CAddTable()
  )
)

Note: I removed the input, output, and kernel size info.
Thanks

You could iterate over all modules and add the forward hooks using their names.
Here is a small example:

import torch
import torch.nn as nn

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


model = nn.Sequential(
    nn.Conv2d(1, 1, 3, 1, 1),
    nn.BatchNorm2d(1),
    nn.ReLU(),
    nn.Sequential(
        nn.Conv2d(1, 1, 3, 1, 1),
        nn.BatchNorm2d(1),
        nn.ReLU()
    ),
    nn.Conv2d(1, 1, 3, 1, 1),
    nn.BatchNorm2d(1),
    nn.ReLU()
)
    
for name, module in model.named_modules():
    module.register_forward_hook(get_activation(name))

model(torch.randn(1, 1, 4, 4))

for key in activation:
    print(key, activation[key])

Note that I removed ConcatTable and CAddTable, as they are classes from Torch7, which is no longer under active development. I would therefore recommend switching to PyTorch. :wink:
Have a look at the website for install instructions, if you haven’t already installed it.
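If you prefer numeric indices over names, as in your question: for `nn.Sequential` children the names yielded by `named_modules()` are already numeric strings (e.g. `'0'`, `'2.1'` for a nested module), so you can key the dict by those and skip the root module. A small sketch under that assumption, with a reduced toy model:

```python
import torch
import torch.nn as nn

activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model = nn.Sequential(
    nn.Conv2d(1, 1, 3, 1, 1),
    nn.BatchNorm2d(1),
    nn.Sequential(
        nn.Conv2d(1, 1, 3, 1, 1),
        nn.ReLU()
    )
)

# Skip the root module (empty name) and hook every submodule;
# Sequential children get numeric names like '0' or '2.1'
for name, module in model.named_modules():
    if name:
        module.register_forward_hook(get_activation(name))

model(torch.randn(1, 1, 4, 4))

# '2.1' is the ReLU nested inside the inner Sequential
print(sorted(activation.keys()))  # ['0', '1', '2', '2.0', '2.1']
```

Containers such as the inner `nn.Sequential` are also hooked here, so you get their combined output under `'2'` in addition to the outputs of their children.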
