Hi,
I have the following problem: I would like to save the inputs of all the `nn.Module` submodules of my top-level module. For example, let’s assume I have this very simple model and that, given a batch, I want to save the inputs to `conv1`, `conv2`, `fc1` and `fc2`:
```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # layer arguments elided
        self.conv1 = nn.Conv2d(...)
        self.conv2 = nn.Conv2d(...)
        self.fc1 = nn.Linear(...)
        self.fc2 = nn.Linear(...)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, ...)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```
I can certainly do something like this, but it’s ugly and not very flexible:
```python
def forward(self, x):
    self.input_to_conv1 = x  # <====
    x = self.conv1(x)
    x = F.relu(x)
    self.input_to_conv2 = x  # <====
    x = self.conv2(x)
    x = F.relu(x)
    x = x.view(-1, ...)
    self.input_to_fc1 = x  # <====
    x = self.fc1(x)
    x = F.relu(x)
    self.input_to_fc2 = x  # <====
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
```
Moreover, imagine that I don’t have access to how these modules are built: I know that they are made of `nn.Module`s, but I have no idea of the names and/or structure of the model.
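As a minimal sketch of the “unknown structure” case: `named_modules()` walks a model recursively and yields every submodule together with its dotted name, so the layers can be discovered at runtime instead of being hard-coded. The nested `nn.Sequential` below is a hypothetical stand-in for a model whose structure is not known in advance, and treating “no children” as “leaf layer” is an assumption of this sketch.

```python
import torch.nn as nn

# Hypothetical model whose structure we pretend not to know.
model = nn.Sequential(
    nn.Conv2d(1, 4, 3), nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU()),  # nested container
    nn.Flatten(),
    nn.Linear(8 * 24 * 24, 10),
)

# named_modules() yields ("", model) first, then every descendant with a
# dotted path; keep only leaves, i.e. modules that have no children.
leaves = [name for name, m in model.named_modules() if not list(m.children())]
print(leaves)  # ['0', '1', '2.0', '2.1', '3', '4']
```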
Therefore, the question is: is there a “pytorch-ish” way of saving the inputs of all `nn.Module` children for later use?
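For reference, PyTorch provides forward pre-hooks (`nn.Module.register_forward_pre_hook`), which run just before a module’s `forward` and receive its positional inputs; combined with `named_modules()`, they can record inputs without knowing the layer names. This is a sketch under assumptions, not necessarily the approach anyone in this thread settled on: the layer sizes are hypothetical (the post elides them with `...`) and chosen for 1×28×28 inputs, only leaf modules are hooked, the first positional argument is taken as “the input”, and `save_leaf_inputs` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical sizes so the sketch runs on 1x28x28 inputs.
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)  # 28x28 -> 26x26
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)  # 26x26 -> 24x24
        self.fc1 = nn.Linear(8 * 24 * 24, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


def save_leaf_inputs(model):
    """Hook every leaf submodule; return (inputs dict, hook handles)."""
    inputs = {}
    handles = []
    for name, module in model.named_modules():
        if list(module.children()):
            continue  # skip containers, including the root model ("")

        # Default argument binds the current name (avoids late binding).
        def hook(mod, args, name=name):
            # args is the tuple of positional inputs; detach so the stored
            # tensor does not keep the autograd graph alive.
            inputs[name] = args[0].detach()

        handles.append(module.register_forward_pre_hook(hook))
    return inputs, handles


model = Net()
inputs, handles = save_leaf_inputs(model)
model(torch.randn(2, 1, 28, 28))  # fills `inputs` during the forward pass
for name, t in inputs.items():
    print(name, tuple(t.shape))
for h in handles:
    h.remove()  # unregister the hooks when done
```

Because `F.relu` is a function rather than a module, only `conv1`, `conv2`, `fc1` and `fc2` get hooked here. Drop the `.detach()` if gradients need to flow through the saved tensors.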