I have a pre-trained model and want to add layers at arbitrary positions inside it. My current approach is shown below on a model that consists only of linear layers; it inserts a ReLU activation after each linear layer.
import torch
import torch.nn as nn
class Model(nn.Module):
    """Small feed-forward network built only from ``nn.Linear`` layers.

    No activations are applied. The input flows through ``linear1``,
    ``linear2``, ``sequential_module`` (two hidden-to-hidden linear layers)
    and finally ``linear3``.

    Args:
        num_in: Number of input features.
        num_hidden: Width shared by every hidden layer.
        num_out: Number of output features.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        width = num_hidden
        # Keep the construction order fixed: it decides the registration
        # order (printing / state_dict order) and the order in which the
        # default random initialisation draws from the RNG.
        self.linear1 = nn.Linear(num_in, width)
        self.linear2 = nn.Linear(width, width)
        self.sequential_module = nn.Sequential(
            *(nn.Linear(width, width) for _ in range(2))
        )
        self.linear3 = nn.Linear(width, num_out)

    def forward(self, x):
        """Apply every stage to ``x`` in order and return the result."""
        # The stages are looked up on every call, so a submodule that is
        # replaced after construction is used automatically.
        stages = (
            self.linear1,
            self.linear2,
            self.sequential_module,
            self.linear3,
        )
        for stage in stages:
            x = stage(x)
        return x
def add_relu(module, *, layer_type=torch.nn.Linear, make_layer=None):
    """Recursively insert a new layer after every ``layer_type`` submodule.

    Every matching child ``child`` found in ``module`` (searched recursively)
    is replaced in place by ``torch.nn.Sequential(child, make_layer())``.
    The parent keeps the original attribute name. For example, after the call
    ``model.linear1`` is the wrapping ``Sequential`` and the original layer is
    ``model.linear1[0]``.

    Wrapping also renames the wrapped parameters in ``state_dict()``
    (``linear1.weight`` becomes ``linear1.0.weight``). Load pre-trained
    weights *before* calling this function.

    Args:
        module: Module to modify in place.
        layer_type: Type (or tuple of types) of submodule after which a layer
            is inserted. Defaults to ``torch.nn.Linear``.
        make_layer: Zero-argument callable returning the module to insert.
            It is called once per match, so each insertion gets its own
            instance. Defaults to creating ``torch.nn.ReLU(inplace=True)``.

    Returns:
        None. ``module`` is modified in place.
    """

    def _default_relu():
        return torch.nn.ReLU(inplace=True)

    factory = _default_relu if make_layer is None else make_layer

    # Take a snapshot of the children: the loop re-registers entries of this
    # module, so avoid iterating the live mapping while mutating it.
    for child_name, child in list(module.named_children()):
        if isinstance(child, layer_type):
            # add_module is the explicit API for (re)registering a submodule
            # under an existing name.
            module.add_module(child_name, torch.nn.Sequential(child, factory()))
        else:
            # A matched child is wrapped but not searched; only non-matching
            # children are descended into.
            add_relu(child, layer_type=layer_type, make_layer=factory)
if __name__ == "__main__":
    # Toy dimensions for the demo.
    num_in = 2
    num_hidden = 16
    num_out = 4
    batch_size = 4

    model = Model(num_in, num_hidden, num_out)
    print(model)

    # Wrap every nn.Linear in the model as Sequential(Linear, ReLU).
    add_relu(model)
    print(model)

    # Sanity check: run one random batch through the modified model to
    # confirm it still works end to end.
    with torch.no_grad():
        out = model(torch.randn(batch_size, num_in))
    print(out.shape)
However, with this approach the model's attributes end up with misleading names, as the following output shows:
Model(
(linear1): Sequential(
(0): Linear(in_features=2, out_features=16, bias=True)
(1): ReLU(inplace=True)
)
(linear2): Sequential(
(0): Linear(in_features=16, out_features=16, bias=True)
(1): ReLU(inplace=True)
)
(sequential_module): Sequential(
(0): Sequential(
(0): Linear(in_features=16, out_features=16, bias=True)
(1): ReLU(inplace=True)
)
(1): Sequential(
(0): Linear(in_features=16, out_features=16, bias=True)
(1): ReLU(inplace=True)
)
)
(linear3): Sequential(
(0): Linear(in_features=16, out_features=4, bias=True)
(1): ReLU(inplace=True)
)
(softmax): Softmax(dim=1)
)
Here, `model.linear1` is actually a `Sequential` layer. How can this be avoided? And is there a better way to add layers anywhere in an existing model? I ask because I have read that using `setattr` is not good style.