I'm in the process of implementing a variational autoencoder on CIFAR10, and I've run into a weird problem: after writing the model code, I trained it and saw that the model didn't learn anything at all.
I used pdb to track down where the problem came from and found that the loss was just nan. Checking how the loss is calculated showed that the reconstruction loss was the source of the nans, so I kept going up the chain and found that the very first layer of my nn.Module (a Conv2d layer) already outputs nan.
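(Side note: besides stepping through with pdb, a forward-hook sweep like the sketch below can flag the first submodule whose output contains nan; find_first_nan is just an illustrative helper, not part of my actual code:)

def find_first_nan(model, sample):
    # register a forward hook on every submodule that reports nan in its output
    handles = []
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                print(f"nan in output of {name} ({module.__class__.__name__})")
        return hook
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    model(sample)
    # clean up the hooks afterwards
    for handle in handles:
        handle.remove()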
I wrote a function called debug to demonstrate the nan output directly. Here's my code; you can run it as-is to first train the model and then check the output of debug.
# pytorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torchvision
from torchsummary import summary
# misc imports
from tqdm import tqdm
import matplotlib.pyplot as plt
from random import randint
class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_dimensions=4):
        super().__init__()
        self.latent_dimensions = latent_dimensions
        # the encoder neural net, representing the posterior probability
        self.encoder = nn.Sequential(
            # first convolutional layer
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(num_features=8),
            nn.ReLU(),
            # second convolutional layer
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            # third convolutional layer
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            # fourth convolutional layer
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            # flattening tensor before linear layers
            nn.Flatten(),
            # first linear layer
            nn.Linear(64 * 3 * 3, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            # second linear layer
            nn.Linear(256, 128),
            nn.BatchNorm1d(num_features=128),
        )
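        # (shape trace for clarity: 32x32 input -> conv s2 -> 15x15 -> conv s2 -> 7x7
        #  -> conv s1 -> 5x5 -> conv s1 -> 3x3, which is where 64 * 3 * 3 comes from)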
        self.mean_linear = nn.Linear(in_features=128, out_features=latent_dimensions)
        self.variance_linear = nn.Linear(
            in_features=128, out_features=latent_dimensions
        )
        # the decoder neural net, representing the likelihood probability
        self.decoder = nn.Sequential(
            # first linear layer
            nn.Linear(in_features=latent_dimensions, out_features=128),
            nn.ReLU(),
            # second linear layer
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            # third linear layer
            nn.Linear(in_features=256, out_features=64 * 3 * 3),
            nn.ReLU(),
            # unflatten for ConvTranspose2d
            nn.Unflatten(dim=1, unflattened_size=(64, 3, 3)),
            # first conv layer
            nn.ConvTranspose2d(
                in_channels=64, out_channels=32, kernel_size=(3, 3), stride=1
            ),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            # second conv layer
            nn.ConvTranspose2d(
                in_channels=32, out_channels=16, kernel_size=(3, 3), stride=1
            ),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            # third conv layer
            nn.ConvTranspose2d(
                in_channels=16, out_channels=8, kernel_size=(3, 3), stride=2
            ),
            nn.BatchNorm2d(num_features=8),
            nn.ReLU(),
            # fourth conv layer
            nn.ConvTranspose2d(
                in_channels=8,
                out_channels=3,
                kernel_size=(3, 3),
                stride=2,
                output_padding=1,
            ),
            nn.BatchNorm2d(num_features=3),
            nn.ReLU(),
        )
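        # (mirror shape trace: 3x3 -> 5x5 -> 7x7 -> 15x15 -> 32x32; the
        #  output_padding=1 on the last layer is what recovers 32 exactly)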
        # a separate Kullback-Leibler divergence term to store after encoding
        self.KLDivergence = 0.0
        # model optimizer
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    def loss(self, sample, prediction):
        # per-channel squared Frobenius norm of the error (shape (B, C)),
        # then an L2 norm over the channel dimension -> shape (B,)
        reconstruction_loss = torch.linalg.vector_norm(
            (torch.linalg.matrix_norm(sample - prediction) ** 2) * 0.5, dim=1
        )
        # return the mean loss for this batch
        return torch.mean(reconstruction_loss + self.KLDivergence)
    def forward(self, sample):
        encoding = self.encoder(sample)
        # generating parameters for encoded distribution
        mean = self.mean_linear(encoding)
        variance = self.variance_linear(encoding)
        covariance_matrix = torch.diag_embed(variance)
        # calculate the KL divergence to a standard normal prior; the closed
        # form is 0.5 * (tr(Sigma) + ||mu||^2 - k - log det(Sigma))
        self.KLDivergence = (
            torch.linalg.vector_norm(mean, dim=1) ** 2
            + variance.sum(dim=1)
            - self.latent_dimensions
            - torch.log(torch.linalg.matrix_norm(covariance_matrix))
        ) * 0.5
        # sampling from the encoded distribution (reparameterization trick);
        # randn_like keeps the noise on the same device/dtype as mean
        latent_variable = mean + variance * torch.randn_like(mean)
        decoding = self.decoder(latent_variable)
        return decoding
    def train_epoch(self, dataloader):
        self.train()  # set model to training mode
        epoch_loss = 0.0
        # create device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        for sample, _ in dataloader:
            # move sample to appropriate device
            sample = sample.to(device)
            self.optimizer.zero_grad()  # clear gradients
            prediction = self.forward(sample)  # forward pass through model
            loss = self.loss(sample, prediction)  # calculate loss for this pass
            loss.backward()  # calculate gradients
            self.optimizer.step()  # update parameters
            epoch_loss += loss.detach().item()  # record loss
        # loss is already a per-batch mean, so average over batches, not samples
        return epoch_loss / len(dataloader)
    def test_epoch(self, dataloader):
        self.eval()
        eval_loss = 0.0
        # create device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            for sample, _ in dataloader:
                # move sample to appropriate device
                sample = sample.to(device)
                prediction = self.forward(sample)
                loss = self.loss(sample, prediction)
                eval_loss += loss.detach().item()
        return eval_loss / len(dataloader)  # average per-batch loss
    def train_model(self, plot=False):
        data_dir = "dataset"
        train_dataset = torchvision.datasets.CIFAR10(
            data_dir, train=True, download=True
        )
        train_dataset.transform = torchvision.transforms.ToTensor()
        m = len(train_dataset)
        train_data, val_data = random_split(
            train_dataset, [m - int(m * 0.2), int(m * 0.2)]
        )
        batch_size = 256
        num_epochs = 100
        train_loader = DataLoader(train_data, batch_size=batch_size)
        eval_loader = DataLoader(val_data, batch_size=batch_size)
        training_loss = []
        evaluation_loss = []
        for epoch in tqdm(range(num_epochs)):
            loss = self.train_epoch(train_loader)
            training_loss.append(loss)
            loss = self.test_epoch(eval_loader)
            evaluation_loss.append(loss)
        if plot:
            fig, ax = plt.subplots(1, 2, figsize=(8, 6))
            plt.tight_layout()
            ax[0].set_title("Training Loss")
            ax[0].plot(training_loss)
            ax[1].set_title("Evaluation Loss")
            ax[1].plot(evaluation_loss)
            plt.show()
        return
def demo():
    data_dir = "dataset"
    batch_size = 256
    model = torch.load("vae_CIFAR.model")
    model.eval()
    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()
    starting_index = randint(0, len(test_dataset) - 10)
    for sample_index in range(starting_index, starting_index + 10):
        sample, _ = test_dataset[sample_index]
        fig, ax = plt.subplots(1, 2, figsize=(8, 6))
        plt.tight_layout()
        initial_numpy = torch.permute(sample, dims=(1, 2, 0)).numpy()
        ax[0].set_title("Initial Data Sample")
        ax[0].imshow(initial_numpy)
        # permute CHW -> HWC before plotting; reshaping (1, 3, 32, 32) straight
        # to (32, 32, 3) would scramble the channels
        generated_numpy = torch.permute(
            model(sample.unsqueeze(0)).squeeze().detach(), dims=(1, 2, 0)
        ).numpy()
        ax[1].set_title("Generated Data Sample")
        ax[1].imshow(generated_numpy)
        plt.show()
def test():
    data_dir = "dataset"
    batch_size = 256
    model = torch.load("vae_CIFAR.model")
    model.eval()
    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    test_loss = []
    # create device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        for sample, _ in tqdm(test_loader):
            # move sample to appropriate device
            sample = sample.to(device)
            prediction = model(sample)
            loss = model.loss(sample, prediction)
            test_loss.append(loss.detach().item())
    plt.figure()
    plt.title("Testing Loss")
    plt.plot(test_loss)
    plt.show()
    return
def debug():
    data_dir = "dataset"
    batch_size = 256
    model = torch.load("vae_CIFAR.model").encoder
    model.eval()
    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()
    sample, _ = test_dataset[0]
    sample = sample.unsqueeze(0)
    layer = model[0]  # fetch the first conv layer of the trained encoder
    # a freshly initialized layer with the same shape, for comparison
    demo_layer = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3), stride=2)
    predict = layer(sample)
    predict_demo = demo_layer(sample)
    print(f"The model output: \n {predict}")
    print(f"The singleton conv layer output: \n {predict_demo}")
if __name__ == "__main__":
    # create device to train on GPU (if available)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # demo()
    # test()
    # create model and move to device
    model = VariationalAutoEncoder().to(device)
    model.train_model(plot=True)
    torch.save(model, "vae_CIFAR.model")
    debug()
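For what it's worth, a quick way to check whether the trained weights themselves contain nan (rather than only the activations) is a small sketch like this, run against the checkpoint saved above:

checkpoint = torch.load("vae_CIFAR.model")
for name, parameter in checkpoint.named_parameters():
    if torch.isnan(parameter).any():
        print(f"nan found in parameter: {name}")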
I would appreciate any and all input. Thank you in advance.