I want to replace activation functions and other layers in my model. I used the idea from this question: https://discuss.pytorch.org/t/how-to-replace-a-layer-with-own-custom-variant/43586

However, the code below replaces all of the activation functions except those inside the sequential module. Any idea what I am doing wrong? How can I replace the ReLU inside the sequential module as well?
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, num_in, num_hidden, num_out, seed=None):
        super().__init__()
        dropout_rate = 0.2
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.dropout2 = nn.Dropout(p=dropout_rate)
        self.linear1 = nn.Linear(num_in, num_hidden)
        self.linear2 = nn.Linear(num_hidden, num_hidden)
        self.linear3 = nn.Linear(num_hidden, num_out)
        self.sequential_module = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = self.relu1(self.linear1(x))
        x = self.dropout1(x)
        x = self.relu2(self.linear2(x))
        x = self.dropout2(x)
        x = self.sequential_module(x)
        x = F.softmax(self.linear3(x), dim=1)
        return x


def replace_layer(module):
    # Swap every ReLU attribute on this module for a Sigmoid.
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.ReLU:
            new_bn = torch.nn.Sigmoid()
            setattr(module, attr_str, new_bn)
    # Recurse into the child modules so nested layers are visited too.
    for name, immediate_child_module in module.named_children():
        replace_layer(immediate_child_module)


if __name__ == "__main__":
    model = MLP(num_in=2, num_hidden=16, num_out=3)
    print(model)
    replace_layer(model)
    print(model)
The code above produces the following output:
MLP(
  (dropout1): Dropout(p=0.2, inplace=False)
  (dropout2): Dropout(p=0.2, inplace=False)
  (linear1): Linear(in_features=2, out_features=16, bias=True)
  (linear2): Linear(in_features=16, out_features=16, bias=True)
  (linear3): Linear(in_features=16, out_features=3, bias=True)
  (sequential_module): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (relu1): ReLU()
  (relu2): ReLU()
)
MLP(
  (dropout1): Dropout(p=0.2, inplace=False)
  (dropout2): Dropout(p=0.2, inplace=False)
  (linear1): Linear(in_features=2, out_features=16, bias=True)
  (linear2): Linear(in_features=16, out_features=16, bias=True)
  (linear3): Linear(in_features=16, out_features=3, bias=True)
  (sequential_module): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
  )
  (relu1): Sigmoid()
  (relu2): Sigmoid()
)
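
For reference, here is a minimal check I can run after replace_layer(model) (nothing beyond the model above is assumed) to list the ReLUs that survive. named_modules() walks the whole tree, including the numbered entries inside the Sequential, so it should print only the one at sequential_module.1:

    # List every remaining ReLU in the module tree after replacement.
    for name, submodule in model.named_modules():
        if isinstance(submodule, nn.ReLU):
            print(name, submodule)
    # expected: sequential_module.1 ReLU()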