I get your point. You can achieve your purpose with `register_forward_hook`; here is example code:
from functools import partial
import torch as t
from torchvision.models import resnet50
model = resnet50()
data = t.randn(1, 3, 224, 224)
layer_output = {}
def get_all_layers(net):
def hook_fn(m, i, o, n=""):
layer_output[n] = o
for name, layer in net.named_modules():
if hasattr(layer, "_module") and layer._module:
get_all_layers(layer)
elif hasattr(layer, "_parameters") and layer._parameters:
# it's a non sequential. Register a hook
layer.register_forward_hook(partial(hook_fn, n=name))
get_all_layers(model)
model(data)
for k, v in layer_output.items():
print(f"the output shape of layer: {k} is {v.shape}")
For a deeper explanation, you can refer to the article "PyTorch Hooks: Gradient Clipping and Debugging".