You want a forward hook for this. Register a hook on a module with `register_forward_hook`; on every forward call of that module, PyTorch calls your hook function with the module, its inputs and its output, so you can inspect or save them there. https://pytorch.org/docs/stable/nn.html?highlight=forward_hook#torch.nn.Module.register_forward_hook
Here is a working example:
import torch
import torch.nn as nn
import torch.nn.functional as F
def print_tensor_props(self, input, output):
    """Forward hook that prints "<first input shape> => <output shape>".

    Registered via ``register_forward_hook``; ``self`` receives the hooked
    module, ``input`` is the tuple of positional arguments passed to its
    forward(), and ``output`` is what forward() returned. Only the first
    positional input is reported.
    """
    in_shape = input[0].shape
    print(in_shape, end=' => ')
    out_shape = output.shape
    print(out_shape)
class Net(nn.Module):
    """Small CNN classifier used to demonstrate forward hooks.

    Two 3x3 convolutions with padding=1 (so height and width are preserved)
    take the input from 3 to 4 to 8 channels, followed by two fully
    connected layers (512 -> 256 -> 10). ``fc1`` expects 512 features per
    sample, i.e. 8 channels * 8 * 8, so inputs must have shape (N, 3, 8, 8).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, 3, padding=1)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, shape (N, 10).

        Args:
            x: input batch of shape (N, 3, 8, 8).

        Raises:
            RuntimeError: if the spatial size is not 8x8 (fc1 gets the
                wrong number of features).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten everything except the batch dimension. Keeping dim 0 fixed
        # (rather than view(-1, 512)) stops a wrongly sized input, e.g. 16x16,
        # from being silently reshaped into a larger, meaningless batch;
        # fc1 raises a shape error instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
net = Net()
# Attach the same shape-printing hook to each layer we want to inspect.
for layer in (net.conv1, net.conv2, net.fc1, net.fc2):
    layer.register_forward_hook(print_tensor_props)

x = torch.randn(5, 3, 8, 8)
print('x :: ', x.shape)
# The hooks fire inside this call, in the order forward() uses the layers.
out = net(x)
Running this prints the following:
x :: torch.Size([5, 3, 8, 8])
torch.Size([5, 3, 8, 8]) => torch.Size([5, 4, 8, 8])
torch.Size([5, 4, 8, 8]) => torch.Size([5, 8, 8, 8])
torch.Size([5, 512]) => torch.Size([5, 256])
torch.Size([5, 256]) => torch.Size([5, 10])
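If you want to save the inputs to a list instead, you can do the following. Note that a hook receives `input` as a tuple of the layer's positional arguments, so each saved entry is a tuple — hence the `[0]` in the printing loop: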
# Hook inputs recorded in call order: one tuple per hooked forward call.
g_list = []


def save_tensor(self, input, output):
    """Forward hook that stores the module's positional-input tuple in g_list.

    ``self`` receives the hooked module and ``output`` its result; neither
    is recorded. The whole ``input`` tuple is stored, not just its first
    element.
    """
    g_list.extend([input])
net = Net()
# Hook every layer; each forward call appends that layer's input tuple to g_list.
for layer in (net.conv1, net.conv2, net.fc1, net.fc2):
    layer.register_forward_hook(save_tensor)
x = torch.randn(5, 3, 8, 8)
out = net(x)
# Use a distinct loop name so the input batch `x` is not overwritten by the
# last recorded hook input. Each entry is a tuple of positional inputs, hence [0].
for i, layer_inputs in enumerate(g_list):
    print("Input to layer {} has shape ".format(i), layer_inputs[0].shape)
which prints:
Input to layer 0 has shape torch.Size([5, 3, 8, 8])
Input to layer 1 has shape torch.Size([5, 4, 8, 8])
Input to layer 2 has shape torch.Size([5, 512])
Input to layer 3 has shape torch.Size([5, 256])