Hi,
With the following model:
# Width of the latent code (encoder output / decoder input): 16 x 16 = 256.
features = 16 ** 2
# Fully connected encoder of a plain (deterministic) autoencoder -- not a VAE.
class LinearAECenc(nn.Module):
    """Encoder: input -> 768 -> 512 -> latent code.

    Args:
        in_features: Size of the last input dimension. Defaults to the
            module-level ``V * H``.
        latent: Size of the latent code. Defaults to the module-level
            ``features``.
        relu_latent: Apply ReLU to the latent code (the previous behaviour).
            Off by default: a ReLU bottleneck can drive every latent unit to
            zero ("dead ReLU"). Once that happens no gradient reaches
            ``enc1``/``enc12``/``enc2``, so only the decoder keeps training.
    """

    def __init__(self, in_features=None, latent=None, relu_latent=False):
        super().__init__()
        if in_features is None:
            in_features = V * H
        if latent is None:
            latent = features
        hidden = 512 * 3 // 2
        self.enc1 = nn.Linear(in_features=in_features, out_features=hidden)
        self.enc12 = nn.Linear(in_features=hidden, out_features=512)
        self.enc2 = nn.Linear(in_features=512, out_features=latent)
        self.relu_latent = relu_latent

    def forward(self, x):
        """Map inputs whose last dimension is ``in_features`` to latent codes."""
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc12(x))
        z = self.enc2(x)
        # Keep the bottleneck linear by default so the encoder always receives
        # a gradient, even when every pre-activation is negative.
        return F.relu(z) if self.relu_latent else z
class LinearAECdec(nn.Module):
    """Decoder: latent code -> 512 -> 768 -> reconstruction squashed by a sigmoid.

    Args:
        latent: Size of the latent code. Defaults to the module-level
            ``features``.
        out_features: Size of the reconstruction's last dimension. Defaults
            to the module-level ``V * H``.
    """

    def __init__(self, latent=None, out_features=None):
        super().__init__()
        if latent is None:
            latent = features
        if out_features is None:
            out_features = V * H
        hidden = 512 * 3 // 2
        self.dec1 = nn.Linear(in_features=latent, out_features=512)
        self.dec12 = nn.Linear(in_features=512, out_features=hidden)
        self.dec2 = nn.Linear(in_features=hidden, out_features=out_features)

    def forward(self, z):
        """Decode latent codes; the sigmoid keeps every output within [0, 1]."""
        y = F.relu(self.dec1(z))
        y = F.relu(self.dec12(y))
        # NOTE(review): targets outside [0, 1] can never be matched exactly.
        return torch.sigmoid(self.dec2(y))
class AEC(nn.Module):
    """Autoencoder whose encoder and decoder may live on different devices.

    Keyword Args:
        encoder: Encoder module. Defaults to ``LinearAECenc()``.
        decoder: Decoder module. Defaults to ``LinearAECdec()``.
        enc_device: Device for the encoder. Defaults to the module-level
            ``device0``.
        dec_device: Device for the decoder. Defaults to the module-level
            ``device1``.

    ``Tensor.to`` is differentiable, so gradients flow from the decoder back
    across the device boundary into the encoder. ``self.parameters()``
    yields the parameters of both halves.
    """

    def __init__(self, *, encoder=None, decoder=None, enc_device=None, dec_device=None):
        super().__init__()
        self.enc_device = device0 if enc_device is None else enc_device
        self.dec_device = device1 if dec_device is None else dec_device
        if encoder is None:
            encoder = LinearAECenc()
        if decoder is None:
            decoder = LinearAECdec()
        # Assigning nn.Module attributes registers them as submodules, so the
        # optimizer sees their parameters through self.parameters().
        self.enc = encoder.to(self.enc_device)
        self.dec = decoder.to(self.dec_device)

    def forward(self, x):
        """Encode on ``enc_device`` and decode on ``dec_device``; the output is on ``dec_device``."""
        # Call the submodules rather than .forward() so registered hooks run.
        z = self.enc(x.to(self.enc_device))
        return self.dec(z.to(self.dec_device))
and this training function:
# Squared error summed over every element; fit() divides by the dataset size.
criterion = nn.MSELoss(reduction="sum")
# Build the split model. Do NOT call model.to(device) afterwards: that would
# move the encoder off device0 as well and undo the split done inside AEC.
model = AEC()
optimizer = optim.Adam(params=model.parameters(), lr=lr)
def fit(model, dataloader, *, opt=None, loss_fn=None):
    """Train ``model`` for one epoch as an autoencoder and return the epoch loss.

    Each batch is flattened to ``(batch, -1)`` and used both as the model
    input and as the reconstruction target.

    Args:
        model: Module mapping flattened batches to reconstructions.
        dataloader: Iterable of input tensors that exposes ``.dataset``.
        opt: Optimizer to step. Defaults to the module-level ``optimizer``.
        loss_fn: Callable ``loss_fn(reconstruction, target)``. Defaults to the
            module-level ``criterion``.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
        This is a per-sample average when the loss uses ``reduction='sum'``.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    model.train()
    running_loss = 0.0
    for data in dataloader:
        # reshape() behaves like view() but also accepts non-contiguous batches.
        data = data.reshape(data.size(0), -1)
        opt.zero_grad()
        reconstruction = model(data)
        # Compare on whatever device produced the output (device1 for the
        # split AEC) instead of hard-coding it.
        loss = loss_fn(reconstruction, data.to(reconstruction.device))
        loss.backward()
        opt.step()
        running_loss += loss.item()
    return running_loss / len(dataloader.dataset)
# The optimizer built next to the model (over all of model.parameters()) is
# reused here; building a second Adam instance at this point was redundant.
epochs = 5000
progress = tqdm(range(epochs))
for epoch in progress:
    train_epoch_loss = fit(model, train_s3load)
    # NOTE(review): train_loss must be a list defined earlier in the file.
    train_loss.append(train_epoch_loss)
    progress.set_description(f"Loss: {train_loss[-1]:.4f} ")
Only the parameters on device1 are updated. Can you give me a hint on what to do so that all
parameters are updated?
Thanks in advance!