VAE reconstruction learns negative floats for images

I am trying to reimplement this paper on beta-VAEs. The dataset I am working with is dSprites, so the images are white sprites on a black background. When I save the loaded images and my reconstructions to a PNG file, the loaded images look fine, but the reconstructions are completely black. I would like to understand why this happens and how to train the decoder to generate images similar to the dSprites images. First I checked whether my parameters were changing during training, and I believe they are. Then I looked at the reconstruction itself, since that is what I am saving as an image, and realized that the final layer of my decoder was outputting negative values, which does not make sense for a pixel value.

I am training my model as follows, where gamma and C are hyperparameters that I set to 1000 and 1, respectively.

from tqdm import tqdm


def train(model, dataloader, gamma, C):
    # Relies on globals defined elsewhere: `optimizer` and `device`.
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
        data = data.unsqueeze(1)  # B, 64, 64 -> B, 1, 64, 64
        data = data.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        loss = model.final_loss(reconstruction, data, mu, logvar, gamma, C)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss / len(dataloader.dataset)
    return train_loss

My validation code follows a similar structure, but it saves the reconstructed image as a png file.

import torch
from torchvision.utils import save_image
from tqdm import tqdm


def validate(model, dataloader, gamma, C, epoch):
    # Relies on the global `device`; images are saved for the first batch only.
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            data = data.unsqueeze(1)
            data = data.to(device)
            reconstruction, mu, logvar = model(data)
            loss = model.final_loss(reconstruction, data, mu, logvar, gamma, C)

            running_loss += loss.item()

            if i == 0:
                # First row: inputs; second row: reconstructions (raw decoder output).
                num_rows = min(data.size(0), 8)
                both = torch.cat((data[:num_rows], reconstruction[:num_rows]))
                save_image(both.cpu(), f"outputs/output{gamma}-{C}-{epoch}.png",
                           nrow=num_rows)

    val_loss = running_loss / len(dataloader.dataset)
    return val_loss

My model is fairly straightforward, with the layers specified in the disentangling paper.

import torch
import torch.nn as nn


class ReImp(nn.Module):
    """Reimplementation of the paper"""

    def __init__(self):
        super().__init__()

        # Encoder

        self.enc_convLayer = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),   # B, 32, 32, 32
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 16, 16
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 8, 8
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 4, 4
            nn.ReLU()
        )

        self.enc_linLayer = nn.Sequential(
            nn.Linear(4 * 4 * 32, 256),                 # B, 256
            nn.ReLU(),
            nn.Linear(256, 256),                        # B, 256
            nn.ReLU(),
            nn.Linear(256, 20)                          # B, 20
        )

        # Decoder
        self.dec_linLayer = nn.Sequential(
            nn.Linear(10, 256),  # B, 256
            nn.ReLU(),
            nn.Linear(256, 256),  # B, 256
            nn.ReLU(),
            nn.Linear(256, 4 * 4 * 32),  # B, 512
            nn.ReLU()
        )

        self.dec_convLayer = nn.Sequential(
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),  # B, 32, 8, 8
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2,
                               padding=1),                      # B, 32, 16, 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2,
                               padding=1),                      # B, 32, 32, 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1)  # B, 1, 64, 64 (logits, no sigmoid)
        )

    def encode(self, x):
        x = self.enc_convLayer(x)                       # Encode to B, 32, 4, 4
        x = x.view(-1, 4 * 4 * 32)                      # B, 512
        x = self.enc_linLayer(x)
        return x

    def decode(self, sample):
        x = self.dec_linLayer(sample)  # B, 512
        x = x.view(-1, 32, 4, 4)  # B, 32, 4, 4
        x = self.dec_convLayer(x)
        return x

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        sample = mu + std * eps
        return sample

    def forward(self, x):
        x_size = x.size()
        x = self.encode(x)
        mu, logvar = x[:, :10], x[:, 10:]  # split latent layer in half
        sample = self.reparameterize(mu, logvar)
        x = self.decode(sample)
        x = x.view(x_size)
        return x, mu, logvar

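    def final_loss(self, reconstruction, x, mu, logvar, gamma, C):
        # BCEWithLogitsLoss applies the sigmoid internally, so `reconstruction` must be logits.
        criterion = nn.BCEWithLogitsLoss(reduction='sum')
        bce_loss = criterion(reconstruction, x)  # summed over pixels and batch
        # KL(q(z|x) || N(0, I)) per sample, averaged over the batch.
        kl = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean(dim=0)
        # Capacity-constrained term: gamma * |KL - C|.
        kld = gamma * torch.abs(kl - C)
        return bce_loss + kld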
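To make the objective in final_loss easier to test in isolation, here is a standalone sketch of the same computation, BCE on logits plus gamma * |KL - C|. The function name capacity_vae_loss is hypothetical and not part of the original code; it also returns the two terms separately so each can be inspected. With mu = 0 and logvar = 0 the KL term is exactly 0, and with all-zero logits every pixel's BCE is ln 2, which gives an easy sanity check.

```python
import torch
import torch.nn as nn


def capacity_vae_loss(logits, x, mu, logvar, gamma, C):
    """Hypothetical standalone version of final_loss: BCE + gamma * |KL - C|.

    logits:     raw decoder outputs (no sigmoid), same shape as x
    x:          binary targets in {0, 1}
    mu, logvar: (batch, latent_dim) posterior parameters
    """
    # Summed BCE on logits; the sigmoid is applied inside the loss.
    bce = nn.BCEWithLogitsLoss(reduction="sum")(logits, x)
    # Closed-form KL(q(z|x) || N(0, I)) per sample, then averaged over the batch.
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl = kl_per_sample.mean()
    total = bce + gamma * torch.abs(kl - C)
    return total, bce, kl


# Sanity check: zero logits and a standard-normal posterior.
logits = torch.zeros(2, 1, 4, 4)
x = torch.randint(0, 2, (2, 1, 4, 4)).float()
mu = torch.zeros(2, 10)
logvar = torch.zeros(2, 10)
total, bce, kl = capacity_vae_loss(logits, x, mu, logvar, gamma=1000.0, C=1.0)
print(total.item(), bce.item(), kl.item())
```

One thing this makes visible: bce is summed over every pixel in the whole batch, while kl is averaged over the batch, so the relative weight of the two terms changes with the batch size. Dividing bce by the batch size is one way to put both terms on a per-sample scale.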

I also wondered whether loading the uint8 array from NumPy as floats would affect my reconstruction, but the dataset seems to contain only 0s (black) and 1s (white).

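import os

import numpy as np
import torch
from torch.utils.data import Dataset

root = os.path.join(os.getcwd(), 'dsprites-dataset-master',
                    'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
data = np.load(root)
data = torch.from_numpy(data['imgs']).float()  # uint8 {0, 1} -> float32 {0., 1.}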


class CustomDataset(Dataset):
    """DSprites Dataset"""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.data.size(0)


dataset = CustomDataset(data)
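To confirm the 0/1 observation directly rather than by eye, a small check like the following could be run on the raw array. This is a sketch: summarize_images is a hypothetical helper, and the array below is a synthetic stand-in; on the real file you would pass np.load(root)['imgs'] before converting it to a tensor.

```python
import numpy as np
import torch


def summarize_images(imgs: np.ndarray) -> dict:
    """Report dtype, shape, and the distinct pixel values of an image array."""
    values = np.unique(imgs)
    return {
        "dtype": str(imgs.dtype),
        "shape": imgs.shape,
        "values": values.tolist(),
        "is_binary": bool(np.isin(values, [0, 1]).all()),
    }


# Synthetic stand-in for the dSprites 'imgs' array: a white square on black.
fake_imgs = np.zeros((4, 64, 64), dtype=np.uint8)
fake_imgs[:, 20:40, 20:40] = 1

info = summarize_images(fake_imgs)
print(info)

# Casting uint8 -> float32 keeps the values 0 and 1 unchanged.
as_float = torch.from_numpy(fake_imgs).float()
print(torch.unique(as_float))
```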

I realize this is a lot of code, and I appreciate anyone who reads through it.

nn.BCEWithLogitsLoss expects the model output to contain logits, which are not bounded to a specific range.
Based on your code, this loss calculation seems correct (I haven't checked the KL loss).
However, since your target images contain zeros and ones, you would need to apply torch.sigmoid to the model output before saving the predictions.
Negative logits correspond to a low probability (close to 0), while positive logits correspond to a high probability (close to 1). Without the sigmoid, save_image clamps every negative value to 0, so an output dominated by negative logits is saved as a black image.
You could also (or instead) apply a threshold so that the predicted image contains only zeros and ones.
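A minimal sketch of that suggestion, assuming the decoder keeps returning raw logits as in ReImp. The helper name logits_to_image and its threshold parameter are hypothetical. Note that thresholding the sigmoid output at 0.5 is the same as checking whether the logit is positive.

```python
from typing import Optional

import torch


def logits_to_image(logits: torch.Tensor, threshold: Optional[float] = None) -> torch.Tensor:
    """Map raw decoder outputs (logits) to pixel values in [0, 1].

    If `threshold` is given (e.g. 0.5), binarize the probabilities so the
    image contains only 0s and 1s, like the dSprites targets.
    """
    probs = torch.sigmoid(logits)  # logit < 0 -> prob < 0.5, logit > 0 -> prob > 0.5
    if threshold is not None:
        return (probs > threshold).float()
    return probs


example_logits = torch.tensor([-3.0, -0.1, 0.0, 0.2, 4.0])
print(logits_to_image(example_logits))
print(logits_to_image(example_logits, threshold=0.5))
```

In validate, this would wrap only the reconstruction that is passed to save_image, for example both = torch.cat((data[:num_rows], logits_to_image(reconstruction[:num_rows]))). The loss should keep using the raw logits, because BCEWithLogitsLoss applies the sigmoid itself.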


Thank you for answering!

I figured out that I was not using enough training data, only around 10,000 images. After training on 100,000 images, the reconstructions started to show sprite-like objects. Applying a sigmoid to the decoder output, outside the network, makes sense for forcing the output into the range between 0 and 1. However, I think that since the weights of the output layer are positive after further training, adding a sigmoid afterwards would change the weights but not the output.
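Since the fix here was training on more images, one way to draw a larger random training set plus a held-out test set from the full dSprites array is torch.utils.data.random_split. This is a sketch under assumptions: split_dataset is a hypothetical helper, and the images left over after the two splits are simply not used.

```python
import torch
from torch.utils.data import random_split


def split_dataset(dataset, n_train, n_test, seed=0):
    """Randomly pick n_train training and n_test test samples; the rest is unused."""
    n_unused = len(dataset) - n_train - n_test
    if n_unused < 0:
        raise ValueError("n_train + n_test exceeds the dataset size")
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    train_set, test_set, _ = random_split(
        dataset, [n_train, n_test, n_unused], generator=generator
    )
    return train_set, test_set
```

With the names from the question, this would be called as train_set, test_set = split_dataset(dataset, 100_000, 10_000), after which each subset can be wrapped in its own DataLoader.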