I was wondering whether there is an easy way to access the first layer of a custom, non-Sequential CNN. When the network is built the ‘sequential way’ (as an nn.Sequential), you can just use network[0]. Is there anything similar for custom modules?
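For comparison, here is a minimal sketch of the Sequential case (the layer sizes are arbitrary assumptions). Because nn.Sequential runs its submodules in the order they were added, network[0] is the layer that is both registered and executed first:

```python
import torch.nn as nn

# A network built as nn.Sequential supports integer indexing
network = nn.Sequential(
    nn.Conv2d(3, 16, 3, 1, 1),
    nn.ReLU(),
    nn.Conv2d(16, 16, 3, 1, 1),
)

# Index 0 is the first submodule in construction (and execution) order
first = network[0]
print(first)
```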
In custom modules the notion of a “first layer” might be misleading, as this example shows:
```python
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # conv2 is registered before conv1 ...
        self.conv2 = nn.Conv2d(3, 3, 3, 1, 1)
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        # ... but conv1 is executed first
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

model = MyModel()
print(model)
print(next(model.named_children()))
```
Both model.named_children() and the print(model) statement will report conv2 as the “first” layer, since they follow the order in which the modules were registered in __init__. In the end, however, the way the modules are used in forward defines the order of execution.
Anyway, next(model.named_children()) might work for you.
Note that finding the “first layer” also depends on what you mean by layer, as nn.Modules might themselves contain more modules, as seen here:
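```python
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # conv2 is a whole ResNet, i.e. a module containing many layers
        self.conv2 = models.resnet18()
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

model = MyModel()
print(model)
print(next(model.named_children()))
```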
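To see how the nesting affects iteration, here is a small sketch that swaps the resnet18 for a hypothetical two-layer nn.Sequential (to avoid the torchvision dependency). named_children() yields only the direct submodules, while named_modules() recurses and starts with the model itself under the empty name '':

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-in for a nested model such as resnet18
        self.conv2 = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1),
            nn.BatchNorm2d(8),
        )
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))

model = MyModel()

# named_children(): direct submodules only, in registration order
child_names = [name for name, _ in model.named_children()]
print(child_names)  # ['conv2', 'conv1']

# named_modules(): recursive pre-order traversal, beginning with the model itself
module_names = [name for name, _ in model.named_modules()]
print(module_names)  # ['', 'conv2', 'conv2.0', 'conv2.1', 'conv1']
```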
Oh I see, I did not take this into consideration. In my case, “first” means the first convolutional layer. Since most of the time (at least with the models I am working with) the first layer is a convolutional layer, I mixed up the terms.
Is there a similar way to get the order of execution in the forward method, then?
Not a beautiful approach, but you could register forward hooks for each layer and print its name when it’s called.
Based on the output you would see which layer was called first in this execution.
Note that this is also not a bulletproof approach, as your forward might contain conditions, loops, etc., so the execution order can change at runtime.
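A minimal sketch of the hook idea, assuming we only care about leaf modules (modules without children) and that one forward pass with a dummy input of shape (1, 3, 8, 8) is representative. The helper name make_hook is hypothetical. Functional calls such as F.relu are not modules, so they will not show up in the recorded order:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(3, 3, 3, 1, 1)
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

model = MyModel()
call_order = []  # filled by the hooks during the forward pass

def make_hook(name):
    # Forward hooks are called as hook(module, input, output) after the module's forward
    def hook(module, inp, out):
        call_order.append(name)
    return hook

handles = []
for name, module in model.named_modules():
    if list(module.children()):  # skip containers, including the model itself
        continue
    handles.append(module.register_forward_hook(make_hook(name)))

with torch.no_grad():
    model(torch.randn(1, 3, 8, 8))

for h in handles:
    h.remove()  # detach the hooks again

print(call_order)  # ['conv1', 'conv2']
```

The recorded order reflects this particular input only; a forward with branches or loops could produce a different order for other inputs.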
I guess for the majority of models you might be fine iterating over model.named_modules() or .named_children(), filtering with if isinstance(module, nn.Conv2d), and printing the first occurrence.
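A short sketch of that filter on a hypothetical model (the attribute names features and fc are assumptions). Because named_modules() recurses, it also finds a Conv2d nested inside a container, and it returns the first one in registration order, which, as discussed above, is not guaranteed to be the first one executed:

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(8, 8, 3, 1, 1),
        )
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.features(x)
        x = x.mean(dim=(2, 3))  # global average pooling
        return self.fc(x)

model = MyModel()

# First nn.Conv2d found while recursing through the submodules;
# falls back to (None, None) if the model contains no Conv2d at all
name, first_conv = next(
    ((n, m) for n, m in model.named_modules() if isinstance(m, nn.Conv2d)),
    (None, None),
)
print(name)  # features.0
```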