Alright, as requested I’m including additional details of my problem. I’m trying to reconstruct simple 1-D
vectors with 8 dimensions (plus a 1-dimensional label vector) with a conditional variational autoencoder (CVAE) model:
import torch
import torch.nn as nn
from tqdm import tqdm

class CVAE(BaseModel):  # BaseModel is my own nn.Module subclass, defined elsewhere
    def __init__(self, in_size, target_size):
        super().__init__()
        # the input is concatenated to the target property
        self.encoder = nn.Sequential(
            nn.Linear(in_size + target_size, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, latent_size * 2),  # latent_size is a global set elsewhere
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size + target_size, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, in_size),
        )

    def reparameterise(self, mu, logvar):
        if self.training:
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std)  # replaces the deprecated .data.new(...).normal_()
            return mu + eps * std
        else:
            return mu

    def encode(self, x, cond):
        x = torch.cat([x, cond], dim=1)
        mu_logvar = self.encoder(x).view(-1, 2, latent_size)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        return mu, logvar

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, cond):
        mu, logvar = self.encode(x, cond)
        z = self.reparameterise(mu, logvar)
        z = torch.cat([z, cond], dim=1)
        x_hat = self.decode(z)
        # first four outputs are non-negative reals, last four form a softmax simplex
        part1 = torch.relu(x_hat[:, :4])
        part2 = torch.softmax(x_hat[:, 4:], dim=1)
        x_hat = torch.cat([part1, part2], dim=1)
        return x_hat, mu, logvar
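As a quick sanity check of the shapes (the latent_size value and the random tensors below are just placeholders for illustration, and BaseModel is assumed to behave like a plain nn.Module):

latent_size = 4  # placeholder value; set elsewhere in my actual code
model = CVAE(in_size=8, target_size=1)
x = torch.rand(16, 8)     # batch of 16 eight-dimensional vectors
cond = torch.rand(16, 1)  # matching one-dimensional labels
x_hat, mu, logvar = model(x, cond)
print(x_hat.shape, mu.shape, logvar.shape)
# torch.Size([16, 8]) torch.Size([16, 4]) torch.Size([16, 4])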
The additional method and class that I use to calculate recon_loss
are the following:
def fit(self, dataloader, optimizer, criterion):
    self.train()
    running_loss = 0.0
    # len(dataloader) already gives the number of batches
    for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
        optimizer.zero_grad()
        reconstruction, mu, logvar = self.forward(data[0], data[1])
        loss = criterion(reconstruction, data[0])  # recon_loss
        loss = vae_loss(loss, mu, logvar)          # recon_loss + KL term
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss / len(dataloader.dataset)
    return train_loss
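In case it helps, this is roughly how fit() gets driven; the optimizer choice, learning rate, and epoch count here are placeholders, and CustomLoss is defined just below:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder choice
criterion = CustomLoss()
for epoch in range(10):  # placeholder epoch count
    train_loss = model.fit(dataloader, optimizer, criterion)
    print(f"epoch {epoch}: train loss {train_loss:.4f}")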
------------------------------------
from torch.nn.modules.loss import _Loss

class CustomLoss(_Loss):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Loss function called at runtime."""
        # Class 1: MSE on the first four (real-valued) components
        class_1_loss = nn.MSELoss()(
            input[:, :4],
            target[:, :4])
        # Class 2: cosine embedding loss on the last four (softmaxed) components;
        # the target vector of ones marks every pair as "should be similar"
        loss_cos = nn.CosineEmbeddingLoss()
        class_2_loss = loss_cos(
            input[:, 4:],
            target[:, 4:],
            torch.ones(input.shape[0], device=input.device))  # keep on the input's device
        return class_1_loss + class_2_loss
Now, taking one batch of data from the DataLoader
and calculating recon_loss:
data = next(iter(data_loader))  # grab a single batch
criterion = CustomLoss()
recon, mu, logvar = model(data[0], data[1])
recon_loss = criterion(recon, data[0])
where
recon_loss
Out[18]: tensor(1281.3059, grad_fn=<AddBackward0>)
Next, in `vae_loss()` we compute
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
KLD
Out[23]:
tensor([10.0125, 9.2514, 8.6930, 11.5028, 6.2790, 11.4213, 11.1161, 8.7398,
8.7730, 11.4224, 10.7966, 8.1383, 7.9794, 8.4988, 8.1407, 7.8110,
7.8110, 7.9266, 7.5567, 7.7040, 9.9480, 7.6494, 10.5388, 9.4197,
10.1101, 8.4504, 8.7865, 8.8465, 10.7121, 8.4979, 8.6045, 8.3936,
8.5527, 7.2297, 8.4476, 10.8846, 10.4522, 11.5168, 7.7652, 9.0135,
8.9415, 10.0520, 9.4360, 10.9811, 9.3636, 10.0256, 7.2043, 7.6451,
7.7242, 7.6227, 10.8506, 8.6256, 7.5179, 10.1579, 8.7330, 10.3774,
9.8613, 8.9309, 10.0385, 9.0261, 9.4356, 9.6858, 10.1660, 8.4929,
8.7968, 7.5675, 8.2790, 8.7619, 9.4661, 10.3707, 10.6991, 10.1204,
11.2257, 11.0965, 8.7320, 10.7721, 9.4106, 9.6219, 8.5730, 12.0483,
6.5600, 10.1521, 10.1500, 10.1347, 10.1345, 10.1373, 8.4438, 6.3176,
8.5711, 8.7008, 9.7572, 11.4712, 10.6697, 11.0056, 10.6899, 10.3529,
8.3172, 10.5426, 8.4198, 10.1392, 11.1788, 8.7461, 8.5806, 7.9725,
8.7498, 10.9897, 11.0135, 8.6260, 9.3328, 9.3445, 9.6393, 6.3542,
9.4578, 9.5768, 11.3704, 9.4054, 11.4350, 9.7938, 9.7804, 10.1121,
11.1820, 6.0557, 10.4342, 7.9478, 6.3706, 11.0652, 9.2101, 8.6107],
grad_fn=<MulBackward0>)
where I’ve just realized that if we do not reduce these values (e.g. with torch.mean),
the final operation in `vae_loss()`, recon_loss + KLD, broadcasts to a batch-sized
vector, and loss.backward() then fails because it expects a scalar
…
So taking the mean seems reasonable mechanically, but I still don’t understand whether it makes sense from a more conceptual standpoint.
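To make the question concrete, here is a minimal sketch of the vae_loss() I have in mind, with the mean reduction applied (whether mean is the right reduction is exactly what I'm unsure about):

def vae_loss(recon_loss, mu, logvar):
    # closed-form KL divergence between q(z|x) = N(mu, diag(sigma^2)) and the
    # standard normal prior N(0, I), summed over latent dimensions, giving
    # one KL value per sample in the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # reduce over the batch so backward() receives a scalar; torch.mean keeps
    # the term on a per-sample scale, while torch.sum would grow with batch size
    return recon_loss + torch.mean(KLD)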