I wrote a class for my model. I want to use an existing model inside it as a backbone, and I want to remove the last three sequential blocks from that backbone.
class SpermModel(nn.Module):
    """Truncated CNN backbone shared by three independent 2-way heads.

    The backbone is truncated by dropping its last ``drop_last`` *top-level*
    children, as yielded by ``backbone.children()``. Only direct children
    count: a nested ``nn.Sequential`` (for example a ``classifier`` block) is
    a single child, so removing layers *inside* it requires slicing that
    submodule instead.

    Args:
        backbone: Module to truncate. When ``None`` (the default), the
            module-level global ``model`` is used, which matches the original
            zero-argument behaviour.
        drop_last: Number of trailing top-level children to remove. It must
            satisfy ``0 <= drop_last <= number of children``. Default ``1``
            (the original ``[:-1]``).
        in_features: Size of the flattened feature vector produced by the
            truncated backbone; it is the input size of each head's first
            ``nn.Linear``. The default 25088 presumably corresponds to a
            512 x 7 x 7 feature map (NOTE(review): confirm against the backbone
            actually used). It must be updated whenever ``drop_last`` changes
            the output shape.

    Raises:
        ValueError: If ``drop_last`` is outside the valid range.
    """

    def __init__(self, backbone=None, *, drop_last=1, in_features=25088):
        super(SpermModel, self).__init__()
        if backbone is None:
            # Backward compatibility: the original read the global ``model``.
            backbone = model
        # The full backbone stays registered under the same attribute name so
        # that existing checkpoints (state_dict keys) still load. Its dropped
        # layers are therefore still part of ``self.parameters()``.
        self.cnn_layers = backbone

        children = list(self.cnn_layers.children())
        if not 0 <= drop_last <= len(children):
            raise ValueError(
                f"drop_last must be between 0 and {len(children)}, "
                f"got {drop_last}"
            )
        # Slice by explicit end index: ``children[:-0]`` would be empty, so a
        # negative-index slice cannot express "drop nothing".
        self.modi = torch.nn.Sequential(*children[:len(children) - drop_last])

        # Module order inside each head matches the original definitions, so
        # parameter names (e.g. fc.1.weight, fc1.4.weight) are unchanged.
        self.fc = self._make_head(in_features)
        self.fc1 = self._make_head(in_features, dropout=0.2)
        self.fc2 = self._make_head(in_features)

    @staticmethod
    def _make_head(in_features, dropout=None):
        """Build a Flatten -> Linear -> ReLU -> [Dropout] -> Linear(1024, 2) head."""
        layers = [
            nn.Flatten(),
            nn.Linear(in_features, 1024),
            nn.ReLU(),
        ]
        if dropout is not None:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(1024, 2))
        return nn.Sequential(*layers)

    def forward(self, item):
        """Run the truncated backbone once and feed its output to all three heads.

        Returns:
            A tuple ``(out1, out2, out3)`` holding the outputs of ``fc``,
            ``fc1`` and ``fc2``. Each has 2 features in its last dimension.
        """
        out = self.modi(item)
        out1 = self.fc(out)
        out2 = self.fc1(out)
        out3 = self.fc2(out)
        return out1, out2, out3
I used `self.modi = torch.nn.Sequential(*(list(self.cnn_layers.children())[:-1]))`
for that, but it only removes one sequential block.