Let’s say I have a pretrained autoencoder, and I just need its pretrained encoder as part of a new model.
import torch
import torch.nn as nn

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        # Encoder: 28*28 input -> 3-dimensional latent code
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3))
        # Decoder: 3-dimensional latent code -> 28*28 reconstruction
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def get_encoder(self):
        return self.encoder
class autoencoder2(nn.Module):
    def __init__(self, encoder):
        super(autoencoder2, self).__init__()
        # Reuse the encoder that is passed in (the pretrained one); only the decoder is new.
        self.encoder = encoder
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def get_encoder(self):
        return self.encoder
pretrn_AE = autoencoder()
checkpoint = torch.load('AE.pt', map_location='cpu')
pretrn_AE.load_state_dict(checkpoint['model'])
pretrn_encoder = pretrn_AE.get_encoder()
new_AE = autoencoder2(pretrn_encoder)
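For context, this is roughly how I plan to train new_AE (just a sketch; train_loader, the MSE loss, and the learning rate are placeholders, not my actual setup):

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(new_AE.parameters(), lr=1e-3)

for imgs, _ in train_loader:
    imgs = imgs.view(imgs.size(0), -1)   # flatten to (batch, 28 * 28)
    recon = new_AE(imgs)
    loss = criterion(recon, imgs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()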
While training new_AE, will this cause a warning or an error saying that some trainable weights are not used (namely the decoder of pretrn_AE)? To play it safe, I would rather feed a clean copy of the pretrained encoder into new_AE.
That way, training new_AE would have nothing to do with pretrn_AE at all.
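In other words, I am thinking of something like this (assuming copy.deepcopy behaves as expected on an nn.Module):

import copy

# Give new_AE its own, independent copy of the encoder weights,
# so nothing in new_AE references modules owned by pretrn_AE.
pretrn_encoder_copy = copy.deepcopy(pretrn_AE.get_encoder())
new_AE = autoencoder2(pretrn_encoder_copy)

An alternative I considered is constructing a fresh encoder with the same architecture and loading the pretrained encoder's state_dict into it, which I assume amounts to the same thing.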
Any advice?
Thanks.