Hi, I am using a model with the code below:
class TextCNN(nn.Module):
    """1-D CNN text classifier: embedding -> conv -> global max-pool -> MLP head.

    Args:
        nb_words: vocabulary size for the embedding table.
        embed_dim: embedding dimension (also Conv1d input channels).
        num_filters: number of Conv1d output channels.
        num_classes: size of the final output layer.
    """

    def __init__(self, nb_words, embed_dim, num_filters, num_classes):
        super(TextCNN, self).__init__()
        self.nb_words = nb_words
        self.embed_dim = embed_dim
        self.num_filters = num_filters
        self.embedding = nn.Embedding(nb_words, embed_dim)
        self.conv = nn.Conv1d(embed_dim, num_filters, kernel_size=7, stride=1)
        self.fc1 = nn.Linear(num_filters, 32)
        self.fc2 = nn.Linear(32, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, seq_len) integer token ids -- assumes seq_len >= 7
        # (the conv kernel size); TODO confirm against callers.
        x = self.embedding(x)              # (batch, seq_len, embed_dim)
        x = x.permute(0, 2, 1)             # Conv1d wants (batch, channels, seq)
        # BUG FIX: F.dropout defaults to training=True, so the original
        # applied dropout even in eval() mode. Tie it to self.training.
        x = F.dropout(x, training=self.training)
        x = self.conv(x).permute(0, 2, 1)  # (batch, conv_seq, num_filters)
        x = F.relu(x).max(dim=1).values    # global max pool over time axis
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return self.sigmoid(x)             # per-class probabilities in [0, 1]
Now, I want to take this model, and modify the last layer. For example, say, initially the num_classes is 3, the output is 3 dimensional. However, I want to use 5 classes by removing the last fc2 layer and re-adding fc2 with num_classes as 5.
I tried doing nn.Sequential(*list(model.children())[:-1]) (or [:-2]) and then appending a new fc2 layer. However, the problem is that the forward function contains custom operations such as permute, so the resulting model raises an error.
What I did instead is write the following class, which takes the old model and rebuilds it with a modified last layer.
class NewCNN(nn.Module):
    """Clone of a trained TextCNN with the final fc2 layer re-sized.

    Every layer except fc2 is copied (weights AND biases) from the given
    TextCNN instance; fc2 is freshly initialized for the new class count.

    Args:
        textCNN_model: a trained TextCNN whose layers are copied.
        num_classes_in_output: output size of the new fc2 head.
    """

    def __init__(self, textCNN_model, num_classes_in_output):
        super(NewCNN, self).__init__()
        self.num_filters = textCNN_model.num_filters
        self.embed_dim = textCNN_model.embed_dim
        self.nb_words = textCNN_model.nb_words
        self.embedding = nn.Embedding(self.nb_words, self.embed_dim)
        self.conv = nn.Conv1d(self.embed_dim, self.num_filters,
                              kernel_size=7, stride=1)
        self.fc1 = nn.Linear(self.num_filters, 32)
        self.fc2 = nn.Linear(32, num_classes_in_output)  # new head, random init
        self.sigmoid = nn.Sigmoid()
        # BUG FIX: the original copied only .weight, silently dropping the
        # trained biases of conv and fc1. load_state_dict copies both.
        self.embedding.load_state_dict(textCNN_model.embedding.state_dict())
        self.conv.load_state_dict(textCNN_model.conv.state_dict())
        self.fc1.load_state_dict(textCNN_model.fc1.state_dict())

    def forward(self, x):
        # x: (batch, seq_len) integer token ids -- assumes seq_len >= 7
        # (the conv kernel size); TODO confirm against callers.
        x = self.embedding(x)              # (batch, seq_len, embed_dim)
        x = x.permute(0, 2, 1)             # Conv1d wants (batch, channels, seq)
        # BUG FIX: F.dropout defaults to training=True, so the original
        # applied dropout even in eval() mode. Tie it to self.training.
        x = F.dropout(x, training=self.training)
        x = self.conv(x).permute(0, 2, 1)  # (batch, conv_seq, num_filters)
        x = F.relu(x).max(dim=1).values    # global max pool over time axis
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return self.sigmoid(x)             # per-class probabilities in [0, 1]
I am wondering if there is a better way to do this?