I am trying to reimplement this paper on β-VAEs. The dataset I am working with is dSprites, so the images are white sprites on a black background. Currently, when I save the loaded image data and my reconstruction to a PNG file, the loaded image looks fine, but the reconstruction is just black. I am wondering why this happens and how to train the decoder to generate images similar to the dSprites images. I first checked whether my parameters change as the model trains, and they appear to. I then inspected the reconstruction itself, since that is what I save as an image, and found that the decoder's final layer produces negative output values, which makes no sense for pixels.
I train my model as follows, where gamma and C are hyperparameters that I set to 1000 and 1, respectively.
def train(model, dataloader, gamma, C):
    # optimizer and device are defined globally
    model.train()
    running_loss = 0.0
    for data in tqdm(dataloader):  # tqdm picks up len(dataloader) automatically
        data = data.unsqueeze(1)   # add a channel dim: B, 1, 64, 64
        data = data.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)  # call the module, not .forward()
        loss = model.final_loss(reconstruction, data, mu, logvar, gamma, C)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss / len(dataloader.dataset)
    return train_loss
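For context, I drive this function with an outer loop along the following lines (a minimal sketch: the Adam settings, epoch count, and the 90/10 split are my own illustrative assumptions, not values from the paper):

# Hypothetical driver loop; everything except gamma=1000 and C=1 is an
# illustrative assumption. dSprites has 737280 images in total.
train_set, test_set = torch.utils.data.random_split(dataset, [663552, 73728])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

model = ReImp().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(25):
    train_loss = train(model, train_loader, gamma=1000, C=1)
    val_loss = validate(model, val_loader, gamma=1000, C=1, epoch=epoch)
    print(f"epoch {epoch}: train {train_loss:.4f}  val {val_loss:.4f}")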
My validation code follows a similar structure, but it also saves the first batch of inputs and reconstructions to a PNG file.
def validate(model, dataloader, gamma, C, epoch):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(tqdm(dataloader)):
            data = data.unsqueeze(1)
            data = data.to(device)
            reconstruction, mu, logvar = model(data)
            loss = model.final_loss(reconstruction, data, mu, logvar, gamma, C)
            running_loss += loss.item()
            if i == 0:
                # save the first few inputs next to their reconstructions
                num_rows = min(data.size(0), 8)
                both = torch.cat((data[:num_rows],
                                  reconstruction.view(-1, 1, 64, 64)[:num_rows]))
                save_image(both.cpu(),
                           f"outputs/output{gamma}-{C}-{epoch}.png",
                           nrow=num_rows)
    val_loss = running_loss / len(dataloader.dataset)
    return val_loss
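One thing I suspect might matter here: since final_loss uses BCEWithLogitsLoss, which applies the sigmoid internally, the raw decoder output is a logit rather than a pixel intensity, so negative values would be expected. If that is the issue, the reconstruction would need to pass through a sigmoid before being saved, something like:

# Untested sketch of what I think the fix might be: squash the logits
# into [0, 1] before handing them to save_image.
probs = torch.sigmoid(reconstruction).view(-1, 1, 64, 64)
both = torch.cat((data[:num_rows], probs[:num_rows]))
save_image(both.cpu(), f"outputs/output{gamma}-{C}-{epoch}.png", nrow=num_rows)

Is that the right way to think about it, or is something else going on?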
My model is fairly straightforward, with the layer sizes taken from the disentangling paper.
class ReImp(nn.Module):
    """Reimplementation of the disentangling paper"""

    def __init__(self):
        super(ReImp, self).__init__()
        # Encoder
        self.enc_convLayer = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),   # B, 32, 32, 32
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 16, 16
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 8, 8
            nn.ReLU(),
            nn.Conv2d(32, 32, 4, stride=2, padding=1),  # B, 32, 4, 4
            nn.ReLU()
        )
        self.enc_linLayer = nn.Sequential(
            nn.Linear(4 * 4 * 32, 256),  # B, 256
            nn.ReLU(),
            nn.Linear(256, 256),         # B, 256
            nn.ReLU(),
            nn.Linear(256, 20)           # B, 20 (10 for mu, 10 for logvar)
        )
        # Decoder
        self.dec_linLayer = nn.Sequential(
            nn.Linear(10, 256),          # B, 256
            nn.ReLU(),
            nn.Linear(256, 256),         # B, 256
            nn.ReLU(),
            nn.Linear(256, 4 * 4 * 32),  # B, 512
            nn.ReLU()
        )
        self.dec_convLayer = nn.Sequential(
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),  # B, 32, 8, 8
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),  # B, 32, 16, 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),  # B, 32, 32, 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1)    # B, 1, 64, 64 (logits, no activation)
        )

    def encode(self, x):
        x = self.enc_convLayer(x)   # B, 32, 4, 4
        x = x.view(-1, 4 * 4 * 32)  # B, 512
        x = self.enc_linLayer(x)
        return x

    def decode(self, sample):
        x = self.dec_linLayer(sample)  # B, 512
        x = x.view(-1, 32, 4, 4)       # B, 32, 4, 4
        x = self.dec_convLayer(x)
        return x

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so std = exp(logvar / 2)
        eps = torch.randn_like(std)
        sample = mu + std * eps
        return sample

    def forward(self, x):
        x_size = x.size()
        x = self.encode(x)
        mu, logvar = x[:, :10], x[:, 10:]  # split the latent layer in half
        sample = self.reparameterize(mu, logvar)
        x = self.decode(sample)
        x = x.view(x_size)
        return x, mu, logvar

    def final_loss(self, reconstruction, x, mu, logvar, gamma, C):
        criterion = nn.BCEWithLogitsLoss(reduction='sum')
        bce_loss = criterion(reconstruction, x)
        # KL divergence to the unit Gaussian prior, averaged over the batch,
        # then pushed toward the capacity C
        kld = gamma * torch.abs(
            (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),
                              dim=1)).mean(dim=0) - C)
        return bce_loss + kld
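If I am reading the paper correctly, final_loss is meant to implement the capacity-controlled objective

L = BCE(reconstruction, x) + gamma * | D_KL( q(z|x) || N(0, I) ) - C |

where the KL divergence to the unit Gaussian prior has the closed form -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2), which is what the torch.sum line computes. (One detail I am unsure about: the BCE is summed over the batch while the KL term is a per-sample mean, so the two terms scale differently with batch size.)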
I also wondered whether loading the uint8 arrays from NumPy as floats would affect the reconstruction, but the dataset appears to contain only 0s (black) and 1s (white).
root = os.path.abspath(os.getcwd() +
                       '/dsprites-dataset-master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
data = np.load(root)
data = torch.from_numpy(data['imgs']).float()


class CustomDataset(Dataset):
    """dSprites dataset"""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.data.size(0)


dataset = CustomDataset(data)
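For what it's worth, that is easy to verify with a quick check on the loaded tensor:

# Sanity check: dSprites images should be strictly binary.
print(torch.unique(data))        # tensor([0., 1.])
print(data.dtype, data.shape)    # torch.float32 torch.Size([737280, 64, 64])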
I know this is a fair amount of code, and I appreciate anyone taking the time to read through it.