Wrapping submodules in Sequential

Is the statement in the link below true?

According to the response, functional calls will be lost when a model's children are wrapped in nn.Sequential.
However, I have seen many repos that use submodules containing functional API calls inside nn.Sequential, for example, the official ConvNeXt implementation:

They use the functional API (F.layer_norm) inside a submodule (class Block) and wrap all the blocks in nn.Sequential.

Maybe the easiest way to understand this is with a little example.

Here I have defined two Modules that perform functional stuff. The SimpleBlock multiplies by 2. The MainModel uses this SimpleBlock in two different places and also performs some functional stuff (multiply by 3 and add 3).

This can, of course, be something more complex like using layers, batch normalization, flatten or whatever you want to do.

```python
import torch


class SimpleBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def functional_stuff(self, x):
        return x * 2

    def forward(self, x):
        # Functional - Multiply by 2
        x = self.functional_stuff(x)
        return x


class MainModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block_0 = SimpleBlock()
        self.block_1 = SimpleBlock()

    def functional_stuff(self, x):
        # This is inside an nn.Module
        x = self.block_0(x)

        # This is functional
        x = x * 3
        return x

    def forward(self, x):
        x = self.functional_stuff(x)

        # Functional - Add 3
        x = x + 3

        # This is an nn.Module.
        # The functional stuff inside this module is kept
        # even if you extract the children with Sequential.
        x = self.block_1(x)
        return x
```

We now instantiate MainModel and wrap its children in nn.Sequential.

```python
x = torch.ones(2, 2)

model = MainModel()
seq_model = torch.nn.Sequential(*list(model.children()))

print(model(x))
print(seq_model(x))
```

Looking at the output, we can see that the two do NOT yield the same result.

```
# Output

# Model
tensor([[18., 18.],
        [18., 18.]])
# Sequential extraction
tensor([[4., 4.],
        [4., 4.]])
```

The full model gives the correct output:
([1 * 2] * 3 + 3) * 2 = (2 * 3 + 3) * 2 = (6 + 3) * 2 = 9 * 2 = 18

However, the nn.Sequential did not multiply by 3 or add 3; it only applied the two SimpleBlocks:
(1 * 2) * 2 = 2 * 2 = 4
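To check that the Sequential really does nothing more than chain the registered children, we can list them with `named_children()` and compare the Sequential's output against calling the two blocks by hand. This sketch re-declares condensed copies of the classes above (same arithmetic, helper methods folded into `forward`) so that it runs on its own.

```python
import torch
import torch.nn as nn


class SimpleBlock(nn.Module):
    def forward(self, x):
        return x * 2  # the functional op lives inside this submodule


class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_0 = SimpleBlock()
        self.block_1 = SimpleBlock()

    def forward(self, x):
        x = self.block_0(x) * 3 + 3  # "* 3" and "+ 3" exist only in this forward
        return self.block_1(x)


model = MainModel()
seq_model = nn.Sequential(*model.children())

# Only the registered submodules are children; the arithmetic is not.
names = [name for name, _ in model.named_children()]
print(names)  # ['block_0', 'block_1']

x = torch.ones(2, 2)
# The Sequential computes block_1(block_0(x)) and nothing else.
print(torch.equal(seq_model(x), model.block_1(model.block_0(x))))  # True
```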

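This is because the multiplication by 2 is wrapped inside an nn.Module (SimpleBlock), whereas the operations in the forward method of MainModel are not. model.children() only returns the registered submodules (block_0 and block_1), in the order they were assigned in __init__, and nn.Sequential simply calls them one after another. It does not matter where a submodule is called from: block_0 is used inside the helper method functional_stuff rather than directly in forward, and it is still picked up. But anything that is not an nn.Module (or a container module like nn.Sequential or others) is lost.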
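One way to make the Sequential extraction match is to turn each functional step into a registered submodule. The sketch below uses a small hypothetical `Lambda` wrapper (it is not part of `torch.nn`; it is defined here for illustration) and registers the modules in the same order they are called, because `children()` follows registration order, not call order.

```python
import torch
import torch.nn as nn


class Lambda(nn.Module):
    """Hypothetical helper (not in torch.nn): wraps a plain function as a module."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class SimpleBlock(nn.Module):
    def forward(self, x):
        return x * 2


class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered in call order so that children() yields them in that order.
        self.block_0 = SimpleBlock()
        self.mul_3 = Lambda(lambda x: x * 3)  # former functional step
        self.add_3 = Lambda(lambda x: x + 3)  # former functional step
        self.block_1 = SimpleBlock()

    def forward(self, x):
        x = self.block_0(x)
        x = self.mul_3(x)
        x = self.add_3(x)
        return self.block_1(x)


model = MainModel()
seq_model = nn.Sequential(*model.children())

x = torch.ones(2, 2)
print(torch.equal(model(x), seq_model(x)))  # True: both produce 18s
```

Storing lambdas as attributes keeps the example short, but it makes the model unpicklable; in real code, small named module classes (or built-ins such as `nn.Flatten`) are the more robust choice.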

In your ConvNeXt example, F.layer_norm is called inside the LayerNorm class defined there, which is itself an nn.Module, so this will work.

Also, the forward method of ConvNeXt calls another method. However, everything inside that method is either an nn.Module or inside a container such as nn.Sequential.

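So if we do something like this:

```python
import torch
import torchvision

# torchvision < 0.13 uses pretrained=True instead of the weights argument
mod = torchvision.models.convnext_tiny(weights="DEFAULT")
seq_mod = torch.nn.Sequential(*list(mod.children()))

mod.eval()
seq_mod.eval()

x = torch.rand(1, 3, 300, 300)

with torch.no_grad():
    print(torch.all(mod(x) == seq_mod(x)))
```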

We get tensor(True) as the output.
If the forward of ConvNeXt itself (not its children, ONLY ConvNeXt) contained any other functional call, the two would no longer give the same result.
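If you need a single module that keeps the functional calls as well, one option is `torch.fx.symbolic_trace`. Note that this is a different technique from rebuilding with nn.Sequential: it records both the submodule calls and the tensor operations in `forward` as a graph and regenerates a `GraphModule`. This sketch assumes `forward` has no data-dependent control flow, which symbolic tracing cannot capture. It reuses a condensed copy of the toy model above.

```python
import torch
import torch.fx
import torch.nn as nn


class SimpleBlock(nn.Module):
    def forward(self, x):
        return x * 2


class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block_0 = SimpleBlock()
        self.block_1 = SimpleBlock()

    def forward(self, x):
        x = self.block_0(x) * 3 + 3  # functional ops that Sequential would drop
        return self.block_1(x)


model = MainModel()
# symbolic_trace records the "* 3" and "+ 3" along with the submodule calls.
traced = torch.fx.symbolic_trace(model)
print(traced.code)  # the regenerated forward, including the mul and add

x = torch.ones(2, 2)
print(torch.equal(model(x), traced(x)))  # True
```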

In the example given in the link, there is a flatten function call in forward; nn.Sequential drops it, and that breaks the pipeline.
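The flatten case can be reproduced with a tiny hypothetical network (the layer sizes are arbitrary and chosen only for illustration). With `torch.flatten` called in `forward`, the rebuilt Sequential passes a 4-D tensor straight into the `Linear` layer and fails. Registering `nn.Flatten` as a submodule between the two layers keeps it in the extracted pipeline.

```python
import torch
import torch.nn as nn


class FunctionalFlattenNet(nn.Module):
    """Hypothetical model: flatten is a plain function call in forward."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # not a module, so Sequential never sees it
        return self.fc(x)


class ModuleFlattenNet(nn.Module):
    """Same layers, but flatten is a registered submodule between conv and fc."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.flatten = nn.Flatten(1)
        self.fc = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        return self.fc(self.flatten(self.conv(x)))


x = torch.rand(2, 3, 8, 8)

broken = nn.Sequential(*FunctionalFlattenNet().children())
try:
    broken(x)  # Linear receives shape (2, 4, 8, 8) and raises
    dropped = False
except RuntimeError:
    dropped = True
print("flatten dropped, Sequential failed:", dropped)

model = ModuleFlattenNet()
fixed = nn.Sequential(*model.children())
print(torch.allclose(model(x), fixed(x)))  # True
```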

Hope this helps and makes it a bit clearer :smile:
