I am trying to reuse a model (transfer learning). I want to remove the last layer (and will later add another layer)
My existing model is
class Model_TL_A(nn.Module):
def __init__(self):
super(Model_TL_A, self).__init__()
self.flatten = nn.Flatten()
self.layers = nn.ModuleList()
current_dim = 784
for n_hidden in (300, 100, 50, 50, 50):
self.layers.append(nn.Linear(current_dim, n_hidden))
self.layers.append(nn.SELU())
current_dim = n_hidden
self.linear1 = nn.Linear(current_dim, 8)
def forward(self, X):
out = self.flatten(X)
for layer in self.layers[:-1]:
out = layer(out)
out = self.linear1(out)
return out
model_tl_a = Model_TL_A()
from torchsummary import summary
summary(model_tl_a, (1,28,28))
Here is my other model
class Model_TL_B_on_A(nn.Module):
def __init__(self, Model_TL_A):
super(Model_TL_B_on_A, self).__init__()
self.model_tl_a = Model_TL_A()
# modules = list(self.model_tl_a.children())[:-1]
# self.model = nn.Sequential(*modules)
def forward(self, X):
out = self.model_tl_a(X)
# out = self.model(X)
return out
It works fine (since I have commented the lines that remove the last layer.
model_tl_b_on_a = Model_TL_B_on_A(Model_TL_A)
from torchsummary import summary
summary(model_tl_b_on_a, (1,28,28))
However, if I uncomment the two lines in Model_TL_B_on_A, and the last line in forward
, I get the following error
TypeError: forward() takes 1 positional argument but 2 were given
Can someone help me out?
Thanks