When to use add_module function?

The Module class has an add_module method for registering submodules. I am a bit confused about what the purpose of this method is.

Most people initialize submodules as plain member variables of the parent module (as in example 1). Is there any case where it is advantageous to initialize submodules using add_module (as in example 2)? What is the recommended way of initializing submodules?

Is there any difference between example 1 and example 2?

Example 1: Initialize submodules as member variables

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1   = nn.Conv2d(3, 16, 5, padding=2)
        self.pool    = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(p=0.5)
        self.conv2   = nn.Conv2d(16, 16, 5, padding=2)
        self.conv3   = nn.Conv2d(16, 400, 11, padding=5)
        self.conv4   = nn.Conv2d(400, 200, 1)
        self.conv5   = nn.Conv2d(200, 1, 1)
```

Example 2: Initialize submodules using add_module

```python
import torch.nn as nn
from torch.nn import Conv2d, Dropout2d, MaxPool2d

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.add_module("conv1", Conv2d(3, 16, 5, padding=2))
        self.add_module("pool", MaxPool2d(2, 2))
        self.add_module("dropout", Dropout2d(p=0.5))
        self.add_module("conv2", Conv2d(16, 16, 5, padding=2))
        self.add_module("conv3", Conv2d(16, 400, 11, padding=5))
        self.add_module("conv4", Conv2d(400, 200, 1))
        self.add_module("conv5", Conv2d(200, 1, 1))
```

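I think example 1 and example 2 do the same thing. The `__setattr__` method of nn.Module doesn't call add_module directly, but it does something similar: when the value being assigned is an nn.Module, it stores it in the same internal `_modules` registry that add_module writes to.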
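A quick way to check this is to build both variants and compare what gets registered. The sketch below uses two cut-down versions of the example networks (the class names NetAttr and NetAdd and the helper try_add are made up for this check). It also shows one small difference: add_module validates the name it is given and, for example, raises a KeyError for a name containing a dot. The exact error message may vary between PyTorch versions.

```python
import torch.nn as nn

class NetAttr(nn.Module):  # hypothetical: submodules assigned as attributes
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, padding=2)
        self.conv2 = nn.Conv2d(16, 16, 5, padding=2)

class NetAdd(nn.Module):  # hypothetical: same submodules via add_module
    def __init__(self):
        super().__init__()
        self.add_module("conv1", nn.Conv2d(3, 16, 5, padding=2))
        self.add_module("conv2", nn.Conv2d(16, 16, 5, padding=2))

a, b = NetAttr(), NetAdd()

# Both register the children under the same names, in the same order...
print([name for name, _ in a.named_children()])  # ['conv1', 'conv2']
print([name for name, _ in b.named_children()])  # ['conv1', 'conv2']
# ...so their state_dict keys match as well.
print(list(a.state_dict().keys()) == list(b.state_dict().keys()))  # True
# Modules added with add_module are still reachable as attributes.
print(type(b.conv1).__name__)  # Conv2d

def try_add(name):
    """Return True if add_module accepts `name`, False if it raises KeyError."""
    m = nn.Module()
    try:
        m.add_module(name, nn.Conv2d(3, 16, 5))
        return True
    except KeyError as err:
        print("rejected:", err)
        return False

try_add("conv1")   # accepted
try_add("conv.1")  # rejected: module names may not contain "."
```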

In general, you won’t need to call add_module. One potential use case is the following:

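```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, modules):
        super(Net, self).__init__()
        # modules: some list of nn.Module instances
        for i, module in enumerate(modules):
            self.add_module(f"module_{i}", module)  # the naming scheme is up to you
```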
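The loop pattern matters because modules kept in a plain Python list are invisible to nn.Module: their parameters are not returned by parameters(), so an optimizer built from them would never update those layers. The sketch below contrasts the two approaches; the class names, the layer{i} naming scheme and the layer sizes are illustrative choices, not anything PyTorch requires. Calling setattr(self, name, module) inside the loop would also register the modules, since it goes through the same `__setattr__` path as example 1.

```python
import torch
import torch.nn as nn

class PlainListNet(nn.Module):  # hypothetical: modules kept in a plain list
    def __init__(self, modules):
        super().__init__()
        self.layers = list(modules)  # plain Python list: NOT registered

class AddModuleNet(nn.Module):  # hypothetical: modules registered in a loop
    def __init__(self, modules):
        super().__init__()
        self.layer_names = []
        for i, module in enumerate(modules):
            name = f"layer{i}"  # arbitrary naming scheme
            self.add_module(name, module)
            self.layer_names.append(name)

    def forward(self, x):
        # Look each registered layer up by name and apply it in order.
        for name in self.layer_names:
            x = getattr(self, name)(x)
        return x

def make_layers():
    return [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)]

plain = PlainListNet(make_layers())
registered = AddModuleNet(make_layers())

print(len(list(plain.parameters())))         # 0: the list is not tracked
print(len(list(registered.parameters())))    # 4: two weights, two biases
print(registered(torch.randn(5, 4)).shape)   # torch.Size([5, 2])
```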

it just looks simple ~

Well, I think the add_module approach from example 2 is simpler when you need to add the same kind of module several times in a for loop.
Let me know if there are other use cases!

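```python
for l in range(3):
    # make_network_a / make_network_b are hypothetical factories that
    # return a new nn.Module on each call
    self.add_module(f'networkA_{l}', make_network_a())
    self.add_module(f'networkB_{l}', make_network_b())
```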
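A self-contained version of that loop might look like the sketch below. It assumes, purely for illustration, that each networkA_l and networkB_l is a single nn.Linear layer and that the two branches are summed at the end; the class name TwoBranchNet and the depth and width parameters are likewise made up. Because the names are built with f-strings, forward() retrieves the layers with getattr using the same pattern.

```python
import torch
import torch.nn as nn

class TwoBranchNet(nn.Module):  # hypothetical example network
    def __init__(self, depth=3, width=8):
        super().__init__()
        self.depth = depth
        for l in range(depth):
            # Each iteration creates fresh instances, so no weights are shared.
            self.add_module(f"networkA_{l}", nn.Linear(width, width))
            self.add_module(f"networkB_{l}", nn.Linear(width, width))

    def forward(self, x):
        a, b = x, x
        for l in range(self.depth):
            a = torch.relu(getattr(self, f"networkA_{l}")(a))
            b = torch.relu(getattr(self, f"networkB_{l}")(b))
        return a + b

net = TwoBranchNet()
# Children are listed in registration order:
# networkA_0, networkB_0, networkA_1, networkB_1, networkA_2, networkB_2
print([name for name, _ in net.named_children()])
print(net(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```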