Accessing intermediate data in nn.Sequential

I’d like to access the output of a module inside an nn.Sequential. For example, I might have a big CNN wrapped in nn.Sequential and want to concatenate the outputs of several layers, or I might just want to inspect intermediate layer outputs for debugging.

Doing this used to be quite easy in Lua Torch because modules cache their output in module.output, but PyTorch apparently doesn’t do that anymore. Is there a way I can access intermediate layers’ outputs elegantly?

Currently I just inherit from nn.Sequential and cache the outputs of the specific modules I want in forward(). I’m not sure if this messes up memory management, since it looks like the intermediate values need to be re-allocated on every forward pass?
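For reference, here is a minimal sketch of that approach; the class and attribute names (CachingSequential, cache_ids, cached) are made up for illustration:

```python
import torch
import torch.nn as nn

class CachingSequential(nn.Sequential):
    """nn.Sequential that stores the outputs of selected sub-modules."""

    def __init__(self, *modules, cache_ids=()):
        super().__init__(*modules)
        self.cache_ids = set(cache_ids)  # indices of modules whose output we keep
        self.cached = {}

    def forward(self, x):
        self.cached = {}
        for i, module in enumerate(self):
            x = module(x)
            if i in self.cache_ids:
                self.cached[i] = x  # keep this intermediate output
        return x

net = CachingSequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 4),
    cache_ids=(1,),  # cache the ReLU output
)
out = net(torch.randn(2, 8))
print(out.shape, net.cached[1].shape)
```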


Subclassing sequential and overriding forward (just like you are doing now) is the right approach.

We wanted to get rid of Sequential itself, but we kept it as a convenience container.

Subclassing Sequential (or caching outputs) doesn’t break memory management; it works out fine.


What would be the consequence of caching every module’s output by adding the following to nn.Module’s __call__()?
> self.output = var

Would this increase memory usage by prolonging the lifetime of output? I’m very curious about how intermediate values are allocated and deallocated.

Yes, if you do that, all output variables will be cached and you will run out of memory (because in PyTorch output buffers are not reused but recreated).


You won’t necessarily run out of memory: if you overwrite the cached outputs on every forward, you keep at most one additional copy of the graph, and that copy will likely have its buffers freed once you call backward on it.
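If you only need the values and not their autograd history, a common pattern is to cache a detached copy, so the cached tensor doesn’t keep the graph alive at all. A small sketch (the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

x = torch.randn(2, 8, requires_grad=True)
h = net[0](x)
cached = h.detach()          # cached copy carries no autograd history
out = net[2](net[1](h)).sum()
out.backward()               # graph buffers are freed as usual after backward
print(cached.requires_grad)  # the cached tensor is outside the graph
```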

About the original question: you might also want to take a look at the register_forward_hook method of Module.
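A hook lets you capture intermediate outputs without subclassing Sequential at all. A minimal sketch (the dict name activations and the helper save_output are made up for illustration):

```python
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 4),
)

activations = {}

def save_output(name):
    # register_forward_hook calls this with (module, inputs, output)
    def hook(module, inputs, output):
        activations[name] = output
    return hook

# attach a hook to the ReLU (index 1); handle.remove() detaches it later
handle = net[1].register_forward_hook(save_output("relu"))

out = net(torch.randn(2, 8))
print(activations["relu"].shape)  # intermediate activation captured
handle.remove()
```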
