Maybe the easiest way to understand this is with a little example. Here I have defined two modules that perform some functional stuff. The SimpleBlock multiplies by 2. The MainModel uses this SimpleBlock in two different places and also performs some functional stuff of its own (multiply by 3 and add 3). These operations can, of course, be something more complex like layers, batch normalization, flatten, or whatever you want to do.
import torch


class SimpleBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def functional_stuff(self, x):
        return x * 2

    def forward(self, x):
        # Functional - multiply by 2
        x = self.functional_stuff(x)
        return x


class MainModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block_0 = SimpleBlock()
        self.block_1 = SimpleBlock()

    def functional_stuff(self, x):
        # This is an nn.Module call
        x = self.block_0(x)
        # This is functional
        x = x * 3
        return x

    def forward(self, x):
        x = self.functional_stuff(x)
        # Functional - add 3
        x = x + 3
        # This is an nn.Module.
        # The functional stuff inside this module is kept
        # even if you extract the children with nn.Sequential.
        x = self.block_1(x)
        return x
We now use the MainModel and extract its children using nn.Sequential.
input = torch.ones(2, 2)
model = MainModel()
seq_model = torch.nn.Sequential(*list(model.children()))
print(model(input))
print(seq_model(input))
If we now look at the generated output, we can see that they do NOT yield the same result.
# Output
# Model
tensor([[18., 18.],
[18., 18.]])
# Sequential extraction
tensor([[4., 4.],
[4., 4.]])
The full model gives the correct output.
((1 * 2) * 3 + 3) * 2 = (2 * 3 + 3) * 2 = (6 + 3) * 2 = 9 * 2 = 18
However, the nn.Sequential version did not multiply by 3 or add 3:
(1 * 2) * 2 = 2 * 2 = 4
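You can see why by printing what the extraction actually kept: only the two SimpleBlock children survive, while the * 3 and + 3 from MainModel are gone.
print(seq_model)
# Sequential(
#   (0): SimpleBlock()
#   (1): SimpleBlock()
# )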
This is because the functional stuff is wrapped inside the nn.Module called SimpleBlock, but the stuff in the forward method of our MainModel is not. We can even put nn.Module calls in a method other than forward and call that method from inside forward. But if the operations are not nn.Modules (or inside a container like nn.Sequential or others), they will not survive the extraction.
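If you need the extraction to preserve those operations, one option is to wrap the functional part in its own small module. Here is a minimal sketch; the MulAdd and MainModelWrapped names are made up for illustration:
class MulAdd(torch.nn.Module):
    # Hypothetical helper that wraps the functional "* 3" and "+ 3"
    # so they survive a children()-based extraction.
    def forward(self, x):
        return x * 3 + 3


class MainModelWrapped(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block_0 = SimpleBlock()
        self.mul_add = MulAdd()
        self.block_1 = SimpleBlock()

    def forward(self, x):
        x = self.block_0(x)
        x = self.mul_add(x)
        x = self.block_1(x)
        return x
With this layout, torch.nn.Sequential(*list(MainModelWrapped().children())) applies block_0, mul_add, and block_1 in order and reproduces the 18 from the full model.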
In your example of ConvNeXt, the F.layer_norm is inside the LayerNorm class used there, so this will work. The forward method of ConvNeXt also calls another method (_forward_impl). Inside this method, however, everything is either an nn.Module or inside a container like nn.Sequential or nn.ModuleList.
So if we do something like this:
import torchvision

mod = torchvision.models.convnext_tiny(pretrained=True)
seq_mod = torch.nn.Sequential(*list(mod.children()))
mod.eval()
seq_mod.eval()

input = torch.rand(1, 3, 300, 300)
print(torch.all(mod(input) == seq_mod(input)))
We get tensor(True) as the output.
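To see why, you can list the top-level children. In torchvision's implementation, every step of ConvNeXt's forward (features, avgpool, classifier) is a registered submodule, so nothing functional is lost by the extraction:
for name, child in mod.named_children():
    print(name, type(child).__name__)
# features Sequential
# avgpool AdaptiveAvgPool2d
# classifier Sequential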
If there were any other form of functional API in the forward of ConvNeXt itself (not in the children, but ONLY in ConvNeXt), then the two would not behave the same. In the example given, there is a flatten step, which would break the pipeline if the extraction dropped it.
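For contrast, torchvision's ResNet does call torch.flatten functionally inside its forward, so the same extraction breaks there. A quick sketch:
import torch
import torchvision

res = torchvision.models.resnet18()
seq_res = torch.nn.Sequential(*list(res.children()))

x = torch.rand(1, 3, 224, 224)
print(res(x).shape)  # works: torch.Size([1, 1000])
# The extracted version skips the functional torch.flatten(x, 1),
# so the final Linear layer receives a 4D tensor and raises a
# shape-mismatch RuntimeError.
seq_res(x)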
Hope this helps and makes it a bit clearer!