Replacing layers in model with named_modules()

I want to replace layers in a PyTorch model and want to use the .named_modules() method. However, my code throws the following error:

File "C:\Users\samuel\Anaconda3\envs\lab\lib\site-packages\torch\nn\modules\module.py", line 1615, in named_modules
    for name, module in self._modules.items():
RuntimeError: OrderedDict mutated during iteration

Here is my code

import torch
import torch.nn as nn


class MLP(nn.Module):

    def __init__(self, num_in, num_hidden, num_out, seed=None):
        super().__init__()
        dropout_rate = 0.2
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.dropout2 = nn.Dropout(p=dropout_rate)

        self.linear1 = nn.Linear(num_in, num_hidden)
        self.linear2 = nn.Linear(num_hidden, num_hidden)
        self.linear3 = nn.Linear(num_hidden, num_out)

        self.sequential_module = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu1(self.linear1(x))
        x = self.dropout1(x)

        x = self.relu2(self.linear2(x))
        x = self.dropout2(x)

        x = self.sequential_module(x)

        x = self.softmax(self.linear3(x))

        return x


def replace_layer(model):
    for module_name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            setattr(model, module_name, nn.Softplus())


if __name__ == "__main__":
    print(torch.__version__)

    model = MLP(num_in=2, num_hidden=16, num_out=3)
    print(model)
    print("\n")

    replace_layer(model)
    print(model)

Any ideas what I can do to solve my problem?

1 Like

Try to delay the replacement, as you are currently manipulating the model while also iterating its named_modules. To do so, store the module_name in a separate list and perform the replacement after the initial loop.

2 Likes

This might help.

import torch
import torch.nn as nn


class MLP(nn.Module):

    def __init__(self, num_in, num_hidden, num_out, seed=None):
        super().__init__()
        dropout_rate = 0.2
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.dropout2 = nn.Dropout(p=dropout_rate)

        self.linear1 = nn.Linear(num_in, num_hidden)
        self.linear2 = nn.Linear(num_hidden, num_hidden)
        self.linear3 = nn.Linear(num_hidden, num_out)

        self.sequential_module = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        self.relu = nn.ReLU()

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.dropout1(x)

        x = self.relu(self.linear2(x))
        x = self.dropout2(x)

        x = self.sequential_module(x)

        x = self.softmax(self.linear3(x))

        return x



print(torch.__version__)

model = MLP(num_in=2, num_hidden=16, num_out=3)
print(model)
print("\n")

relu_lst = [k.split('.') for k, m in model.named_modules(remove_duplicate=False) if isinstance(m, nn.ReLU)]
relu_lst
for *parent, k in relu_lst:
    if len(parent) == 0:
        model.relu = nn.Softplus()
    else:
        model.get_submodule('.'.join(parent))[int(k)] = nn.Softplus()

print(model)