When should I use nn.ModuleList and when should I use nn.Sequential?

Oh I see, yeah that helps a lot. Thanks!

The reason you are seeing NotImplementedError is that ModuleList is really meant to be used like a list. There is no forward operation defined on it, and I doubt there will be in the future.
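A minimal sketch of that distinction (layer sizes are arbitrary):

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
x = torch.randn(1, 4)

# Indexing and iterating work exactly like a plain list:
for layer in layers:
    x = layer(x)

# Calling the ModuleList itself does not work, since it defines no forward:
# layers(x)  # raises NotImplementedError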

6 Likes

Sorry for replying to such an old thread, but I found an interesting use-case where nn.ModuleList kinda saved me. Basically, it helps if you have a module with a variable number of layers:

import numpy as np
import torch as tr
import torch.nn as nn

def getNumParams(params):
	numParams, numTrainable = 0, 0
	for param in params:
		npParamCount = np.prod(param.data.shape)
		numParams += npParamCount
		if param.requires_grad:
			numTrainable += npParamCount
	return numParams, numTrainable

# Using list
class Module1(nn.Module):
	def __init__(self, dIn, dOut, numLayers):
		super(Module1, self).__init__()
		self.layers = []
		for i in range(numLayers - 1):
			self.layers.append(nn.Conv2d(in_channels=dIn, out_channels=dIn, kernel_size=1))
		self.layers.append(nn.Conv2d(in_channels=dIn, out_channels=dOut, kernel_size=1))

	def forward(self, x):
		y = x
		for i in range(len(self.layers)):
			y = self.layers[i](y)
		return y

# Using nn.ModuleList
class Module2(nn.Module):
	def __init__(self, dIn, dOut, numLayers):
		super(Module2, self).__init__()
		self.layers = nn.ModuleList()
		for i in range(numLayers - 1):
			self.layers.append(nn.Conv2d(in_channels=dIn, out_channels=dIn, kernel_size=1))
		self.layers.append(nn.Conv2d(in_channels=dIn, out_channels=dOut, kernel_size=1))

	def forward(self, x):
		y = x
		for i in range(len(self.layers)):
			y = self.layers[i](y)
		return y

def main():
	x = tr.randn(1, 7, 30, 30)

	module1 = Module1(dIn=7, dOut=13, numLayers=10)
	y1 = module1(x)
	print(y1.shape) # (1, 13, 30, 30)
	print(getNumParams(module1.parameters())) # Prints (0, 0)

	module2 = Module2(dIn=7, dOut=13, numLayers=10)
	y2 = module2(x)
	print(getNumParams(module2.parameters())) # Prints (608, 608)
	print(y2.shape) # (1, 13, 30, 30)

if __name__ == "__main__":
	main()

Just my 2c on how this feature saved me, as my code checks the parameter count when loading network weights :slight_smile:

9 Likes

nn.ModuleList comes in handy when writing many DL models. For example, when you are trying to code a Maxout Network as defined in the paper [Maxout Networks](https://arxiv.org/pdf/1302.4389.pdf):

import torch
import torch.nn as nn
import torch.nn.functional as F

class maxout_mlp(nn.Module):
    def __init__(self, num_units=2):
        super(maxout_mlp, self).__init__()
        self.fc1_list = nn.ModuleList()
        self.fc2_list = nn.ModuleList()

        for _ in range(num_units):
            self.fc1_list.append(nn.Linear(784, 1024))
            self.fc2_list.append(nn.Linear(1024, 10))

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.maxout(x, self.fc1_list)
        x = F.dropout(x, training=self.training)
        x = self.maxout(x, self.fc2_list)
        return F.log_softmax(x, dim=1)

    def maxout(self, x, layer_list):
        max_output = layer_list[0](x)  # pass x to the first unit in the layer
        for layer in layer_list[1:]:   # element-wise max over the remaining units
            max_output = torch.max(layer(x), max_output)
        return max_output

4 Likes

Thanks for posting this!
My understanding is that the 2 versions only differ when you look at the parameter count, but the results for y1 and y2 would be the same (assuming same seed for random initialization), correct?

So the only difference between Sequential and ModuleList is that Sequential does not have an append method, which means you cannot add layers to it in a for loop.
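For what it's worth, you can still build an nn.Sequential in a loop, for example via nn.Module.add_module or by unpacking a plain list, as a later reply also shows (a small sketch, sizes are arbitrary):

import torch.nn as nn

seq = nn.Sequential()
for i in range(3):
    # add_module registers each layer under an explicit name
    seq.add_module('linear_%d' % i, nn.Linear(10, 10))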

Stupid question: why use ModuleList instead of a normal Python list? Is it so that parameters are included in the .parameters() iterator?

6 Likes

Exactly! If you use a plain python list, the parameters won’t be registered properly and you can’t pass them to your optimizer using model.parameters().
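A minimal sketch of that difference (layer sizes are arbitrary):

import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super(WithList, self).__init__()
        self.layers = [nn.Linear(10, 10)]                  # plain list: not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super(WithModuleList, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10)])   # registered as submodules

print(len(list(WithList().parameters())))        # 0
print(len(list(WithModuleList().parameters())))  # 2 (weight and bias)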

20 Likes

The question is already answered several times, but I want to share my experience, which may help you think about a practical case.

Firstly, I want to mention again that nn.Sequential stores some layers and already implements a forward method, in which the layers are applied in a cascaded way. The point is, you don't always want layers to be cascaded. In my case, what I needed was to concatenate the outputs of CNN layers having different kernel sizes.

Here is the paper I tried to implement: “Convolutional Neural Networks for Sentence Classification”. As a starting point, I found an implementation on GitHub:

class CNNSentence(nn.Module):
    def __init__(self, args, data, vectors):
        super(CNNSentence, self).__init__()
        ...
        for filter_size in args.FILTER_SIZES:
            conv = nn.Conv1d(self.in_channels,
                             args.num_feature_maps,
                             args.word_dim * filter_size,
                             stride=args.word_dim)
            setattr(self, 'conv_' + str(filter_size), conv)
        ...

    def forward(self, batch):
        ...
        conv_result = [
            F.max_pool1d(F.relu(getattr(self, 'conv_' + str(filter_size))(conv_in)),
                         seq_len - filter_size + 1).view(-1, self.args.num_feature_maps)
            for filter_size in self.args.FILTER_SIZES]

        out = torch.cat(conv_result, 1)
        ...

Source

However, I skipped setting the convolutional layers as attributes while rewriting the model. Later, while transferring the network to the GPU, I got an error and realized that the convolutional layers were not part of my network. My failed code is below:

class CNN_Sentence(nn.Module):
    def __init__(self, ..., ngram_filter_sizes=[3, 4, 5], ...):
        super(CNN_Sentence, self).__init__()
        ...
        self.convs = []
        for ngram_filter in ngram_filter_sizes:
            conv = nn.Conv1d(embedding_size,
                             conv_out_filter,
                             ngram_filter).to(args.device)
            self.convs.append(conv)
        ...

    def forward(self, batch):
        ...
        x = []
        for conv in self.convs:
            conv_out = conv(batch)
            max_pool_kernel = conv_out.shape[2]
            conv_out = F.max_pool1d(F.relu(conv_out),
                                    max_pool_kernel)
            x.append(conv_out.view(batch_size, -1))
        ...

I think it is a good example of why we need nn.ModuleList and why it is different from nn.Sequential. The fix was essentially to store the conv layers in an nn.ModuleList instead of a plain list, as sketched below.
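A minimal sketch of the fix, keeping the argument names from the snippet above (the class name is made up and everything else is simplified):

import torch.nn as nn

class CNN_Sentence_Fixed(nn.Module):
    def __init__(self, embedding_size=128, conv_out_filter=100,
                 ngram_filter_sizes=[3, 4, 5]):
        super(CNN_Sentence_Fixed, self).__init__()
        # nn.ModuleList registers every conv as a submodule,
        # so .to(device) and .parameters() see them
        self.convs = nn.ModuleList()
        for ngram_filter in ngram_filter_sizes:
            self.convs.append(nn.Conv1d(embedding_size,
                                        conv_out_filter,
                                        ngram_filter))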

You can also find the original discussion that led to nn.ModuleList, which helped me discover the class:
https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219

4 Likes

Hi,
I was trying to port an existing PyTorch model from Python to C++ using the libtorch API.
The model has a couple of blocks declared as nn.ModuleList().
How can I implement the same in C++?
Below is a dummy snippet for the same.

self.initial_layer = DummyConv(in_channels, growth_rate * num_layers, dilation=1,
                               kernel_size=kernel_size, pad=pad)
self.layers = nn.ModuleList()
for i in range(1, num_layers):
    self.layers.add_module('layer%s' % i,
                           DummyConv(growth_rate, growth_rate * (num_layers - i),
                                     dilation=i, kernel_size=kernel_size, pad=i))

def forward(self, x):
    out = self.initial_layer(x)
    for i, layer in enumerate(self.layers):
        out[:,(i+1)*self.growth_rate:] += layer(out[:,i*self.growth_rate:(i+1)*self.growth_rate].contiguous())

    return out[:,-self.growth_rate:]

The question was when, and the answer may be: in cases where you need a dynamic module structure and you don't know in advance how it will look.
The original example provided is fair:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

This doesn’t have anything to do with dynamic graph creation, which PyTorch also does.

If I am not wrong, there must be at least one forward method in a PyTorch model, so the ModuleList will be part of a class that evaluates it in that class's forward.

Some also call using a class derived from nn.Module a functional approach.

nn.ModuleList can be a child of an nn.Sequential, and in that case, inside the Sequential we need a class that wraps the ModuleList, since the ModuleList itself has no forward (see the sketch below).
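A minimal sketch of that nesting (the wrapper class name and sizes are made up):

import torch.nn as nn

class Block(nn.Module):
    # wraps an nn.ModuleList and provides the forward it lacks
    def __init__(self):
        super(Block, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = nn.Sequential(Block(), nn.ReLU(), nn.Linear(8, 2))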

Thanks for your great explanation

I wonder if it works if I put multiple nn.Sequential blocks into an nn.ModuleList?

I would assume it should work. Are you seeing any issues with this approach?

I guess the nn.Sequential blocks will register their parameters with my optimizer just like a ModuleList would?

Yes, both nn.Sequential and nn.ModuleList will register their parameters when they are assigned to an nn.Module instance as an attribute.
Also, both provide the .parameters() method, which can be used to pass the parameters to an optimizer.
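A minimal sketch of that setup (sizes are arbitrary):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # several nn.Sequential blocks stored in an nn.ModuleList
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Linear(16, 16), nn.ReLU())
            for _ in range(3)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # all block parameters are registered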

Let me know, if you encounter any issues.

4 Likes

Great explanation and easy to understand!

I came from TF a few days ago and I may be wrong on this, as I am still getting started with PyTorch. I had the same doubt and realized that if you need really fine-grained control over your model, then you probably need nn.ModuleList. Many YOLO implementations (in PyTorch) build the network by appending layer by layer to an nn.ModuleList. I would just take care, as most of the time it will be cleaner to just use nn.Sequential.

With nn.ModuleList:

class MyModule(nn.Module):
    def __init__(self, layers_count=10):
        super(MyModule, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(layers_count)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Now with nn.Sequential:

class MyModule(nn.Module):
    def __init__(self, layers_count=10):
        super(MyModule, self).__init__()
        layers = [nn.Linear(10, 10) for _ in range(layers_count)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

I think both snippets do the same thing, but the 2nd is less verbose, since with ModuleList you need to iterate over the layers yourself…

1 Like

nn.Flatten is available in PyTorch already, no need to write it yourself.
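For example, it can be dropped straight into an nn.Sequential (sizes here are arbitrary):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
out = model(torch.randn(4, 1, 28, 28))  # nn.Flatten collapses dims 1..-1, so out has shape (4, 10)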

I love you! I love you! Best explanation of nn.ModuleList anywhere on the interwebs