I want to replace layers in a PyTorch model and want to use the .named_modules()
method. However, my code throws the following error:
File "C:\Users\samuel\Anaconda3\envs\lab\lib\site-packages\torch\nn\modules\module.py", line 1615, in named_modules
for name, module in self._modules.items():
RuntimeError: OrderedDict mutated during iteration
Here is my code
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, num_in, num_hidden, num_out, seed=None):
super().__init__()
dropout_rate = 0.2
self.dropout1 = nn.Dropout(p=dropout_rate)
self.dropout2 = nn.Dropout(p=dropout_rate)
self.linear1 = nn.Linear(num_in, num_hidden)
self.linear2 = nn.Linear(num_hidden, num_hidden)
self.linear3 = nn.Linear(num_hidden, num_out)
self.sequential_module = nn.Sequential(
nn.Linear(num_hidden, num_hidden),
nn.ReLU(),
nn.Dropout(p=0.2),
)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
x = self.relu1(self.linear1(x))
x = self.dropout1(x)
x = self.relu2(self.linear2(x))
x = self.dropout2(x)
x = self.sequential_module(x)
x = self.softmax(self.linear3(x))
return x
def replace_layer(model):
for module_name, module in model.named_modules():
if isinstance(module, nn.ReLU):
setattr(model, module_name, nn.Softplus())
if __name__ == "__main__":
print(torch.__version__)
model = MLP(num_in=2, num_hidden=16, num_out=3)
print(model)
print("\n")
replace_layer(model)
print(model)
Any ideas what I can do to solve my problem?