Hello,
Slightly new here. I’m trying to use a convolutional VAE to generate images, however my loss remains constant throughout training. I’ve tried using different architectures, and different optimizers. I’ve also tried removing the noise and KLD term as to make the VAE into a normal autoencoder, but the average loss still stays constant. i.e., no change.
One-step training code:
def train_one_step(self, data):
data = data[0]
self.optim.zero_grad()
pred, mu, log = self.vae(data)
loss = losses(pred, data, mu, log, beta=self.beta)
print(loss.item())
self.optim.step()
self.losses.append(loss.item())
VAE code with Block used and loss function used.
class Block(nn.Module):
def __init__(self, inchan, outchan, resdrop, activation=None, kernel_size=5):
super(Block, self).__init__()
self.conv = nn.Conv2d(inchan, outchan, kernel_size)
self.forget_param = nn.Parameter(torch.ones(1))
self.shortcut = nn.Conv2d(inchan, outchan, kernel_size)
self.shortcut_param = nn.Parameter(torch.ones(1))
self.down = nn.Upsample((resdrop, resdrop))
self.norm1 = nn.InstanceNorm2d(inchan)
self.norm1_param = nn.Parameter(torch.zeros(1))
self.norm2 = nn.BatchNorm2d(outchan)
self.norm2_param = nn.Parameter(torch.zeros(1))
self.norm3 = nn.LayerNorm((resdrop, resdrop))
self.norm3_param = nn.Parameter(torch.zeros(1))
self.activation = nn.PReLU(outchan) if activation is None else activation
def forward(self, x):
y = x
y = self.conv(y)
y = self.down(y)
y = y + (self.shortcut_param * self.down(self.shortcut(x)))
y = self.activation(y)
return y
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.conv1 = Block(3, 3, 256)
self.conv2 = Block(3, 3, 128)
self.conv3 = Block(3, 3, 64)
self.conv3_log = Block(3, 3, 32)
self.conv3_mu = Block(3, 3, 32)
self.conv4 = Block(3, 3, 64)
self.conv5 = Block(3, 3, 128)
self.conv6 = Block(3, 3, 256)
self.conv7 = Block(3, 3, 256)
self.conv8 = Block(3, 3, 256, torch.tanh)
def encode(self, x):
encoding = self.conv3(self.conv2(self.conv1(x)))
return self.conv3_mu(encoding), self.conv3_log(encoding)
def reparameterize(self, mu, log):
# eps = 0 * torch.randn_like(log)
return log + mu
def decode(self, x):
decoding = self.conv8(self.conv7(self.conv6(self.conv5(self.conv4(x)))))
return decoding
def forward(self, x):
encode_mu, encode_log = self.encode(x)
x = self.reparameterize(encode_mu, encode_log)
x = self.decode(x)
return x, encode_mu, encode_log
def losses(y_pred, y_true, mu, log, beta=1):
mse = F.binary_cross_entropy(torch.sigmoid(y_pred), torch.sigmoid(y_true), reduction='sum')
kld = -0.5 * torch.sum(1 + log - mu.pow(2) - log.exp())
return mse #+ beta * (kld)
Any advice or tips on training VAEs would be much appreciated.