Conv2D outputs NaN in nn.Module but not on its own

I’m in the process of implementing a variational autoencoder on CIFAR10. I’ve come across a weird problem. After having written the model code, I attempted training it and saw that the model didn’t learn anything at all.

Then I used pdb to see where this problem came from and saw that the loss was just nan. Then I checked how the loss was calculated and saw that the reconstruction loss was the source of the nan problem. To summarize, I kept going up the chain and saw that the first layer of my nn.Module (a Conv2d layer) just outputs nan.

I wrote a function called debug to demonstrate this. Here’s my code. you can just run this to first train the model and then check the output of the debug function.

# pytorch imports
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torchvision
from torchsummary import summary

# misc imports
from tqdm import tqdm
import matplotlib.pyplot as plt
from random import randint


class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_dimensions=4):
        super().__init__()

        self.latent_dimensions = latent_dimensions

        # the encoder neural net, representing the posterior probability
        self.encoder = nn.Sequential(
            # first convolutional layer
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(num_features=8),
            nn.ReLU(),
            # second convolutional layer
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            # third convolutional layer
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            # fourth convolutional layer
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            # flattening tensor before linear layers
            nn.Flatten(),
            # first linear layer
            nn.Linear(64 * 3 * 3, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            # second linear layer
            nn.Linear(256, 128),
            nn.BatchNorm1d(num_features=128),
        )

        self.mean_linear = nn.Linear(in_features=128, out_features=latent_dimensions)

        self.variance_linear = nn.Linear(
            in_features=128, out_features=latent_dimensions
        )

        # the decoder neural net, representing the likelihood probability
        self.decoder = nn.Sequential(
            # first linear layer
            nn.Linear(in_features=latent_dimensions, out_features=128),
            nn.ReLU(),
            # second linear layer
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            # third linear layer
            nn.Linear(in_features=256, out_features=64 * 3 * 3),
            nn.ReLU(),
            # unflatten for ConvTranspose2d
            nn.Unflatten(dim=1, unflattened_size=(64, 3, 3)),
            # first conv layer
            nn.ConvTranspose2d(
                in_channels=64, out_channels=32, kernel_size=(3, 3), stride=1
            ),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            # second conv layer
            nn.ConvTranspose2d(
                in_channels=32, out_channels=16, kernel_size=(3, 3), stride=1
            ),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            # third conv layer
            nn.ConvTranspose2d(
                in_channels=16, out_channels=8, kernel_size=(3, 3), stride=2
            ),
            nn.BatchNorm2d(num_features=8),
            nn.ReLU(),
            # fourth conv layer
            nn.ConvTranspose2d(
                in_channels=8,
                out_channels=3,
                kernel_size=(3, 3),
                stride=2,
                output_padding=1,
            ),
            nn.BatchNorm2d(num_features=3),
            nn.ReLU(),
        )

        # a separate Kullbeik-Leibler divergence term to store after encoding
        self.KLDivergence = 0.0

        # model optimizer
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def loss(self, sample, prediction):
        
        reconstruction_loss = torch.linalg.vector_norm((torch.linalg.matrix_norm(sample - prediction) ** 2) * 0.5,dim=1)

        # return the mean loss for this batch
        return torch.mean(reconstruction_loss + self.KLDivergence)

    def forward(self, sample):
        encoding = self.encoder(sample)

        # generating parameters for encoded distribution
        mean = self.mean_linear(encoding)
        variance = self.variance_linear(encoding)

        covariance_matrix = torch.diag_embed(variance)

        # calculate KLDivergence
        self.KLDivergence = (
            torch.linalg.vector_norm(mean, dim=1) ** 2
            + variance.sum(dim=1)
            - self.latent_dimensions
            - torch.log(torch.linalg.matrix_norm(covariance_matrix))
        ) * 0.5

        # sampling from encoded distribution
        latent_variable = mean + variance * torch.distributions.Normal(0.0, 1.0).sample(
            sample_shape=mean.shape
        )

        decoding = self.decoder(latent_variable)

        return decoding

    def train_epoch(self, dataloader):
        self.train()  # set model to training mode
        epoch_loss = 0.0

        # create device 
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        
        for sample, _ in dataloader:
            
            # move sample to appropriate device
            sample = sample.to(device)
            
            self.optimizer.zero_grad()  # clear gradients

            prediction = self.forward(sample)  # forward pass through model

            loss = self.loss(sample, prediction)  # calculate loss for this pass

            loss.backward()  # calculate gradients

            self.optimizer.step()  # update parameters

            epoch_loss += loss.detach().item()  # record loss

        return epoch_loss / len(dataloader.dataset)  # return average loss

    def test_epoch(self, dataloader):
        self.eval()
        eval_loss = 0.0

        # create device 
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            for sample, _ in dataloader:

                # move sample to appropriate device
                sample = sample.to(device)
                prediction = self.forward(sample)

                loss = self.loss(sample, prediction)

                eval_loss += loss.detach().item()

        return eval_loss / len(dataloader.dataset)

    def train_model(self, plot=False):
        data_dir = "dataset"

        train_dataset = torchvision.datasets.CIFAR10(
            data_dir, train=True, download=True
        )

        train_dataset.transform = torchvision.transforms.ToTensor()

        m = len(train_dataset)

        train_data, val_data = random_split(
            train_dataset, [m - int(m * 0.2), int(m * 0.2)]
        )

        batch_size = 256
        num_epochs = 100

        train_loader = DataLoader(train_data, batch_size=batch_size)
        eval_loader = DataLoader(val_data, batch_size=batch_size)

        training_loss = []
        evaluation_loss = []

        for epoch in tqdm(range(num_epochs)):
            loss = self.train_epoch(train_loader)
            training_loss.append(loss)

            loss = self.test_epoch(eval_loader)
            evaluation_loss.append(loss)

        if plot:
            fig, ax = plt.subplots(1, 2, figsize=(8, 6))
            plt.tight_layout()

            ax[0].set_title("Training Loss")
            ax[0].plot(training_loss)

            ax[1].set_title("Evaluation Loss")
            ax[1].plot(evaluation_loss)

            plt.show()

        return

def demo():
    data_dir = 'dataset'

    batch_size = 256
    
    model = torch.load("vae_CIFAR.model")
    model.eval()

    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()

    starting_index = randint(0,len(test_dataset)-10)

    for sample_index in range(starting_index,starting_index + 10): 

        sample,_ = test_dataset[sample_index]

        fig,ax = plt.subplots(1,2,figsize = (8,6))
        plt.tight_layout()

        initial_numpy = torch.permute(sample,dims = (1,2,0)).numpy().reshape(32,32,3)

        ax[0].set_title("Initial Data Sample")
        ax[0].imshow(initial_numpy)

        generated_numpy = torch.permute(model(sample.unsqueeze(0)).squeeze().detach(),dims = (1,2,0)).numpy().reshape(32,32,3)
        ax[1].set_title("Generated Data Sample")
        ax[1].imshow(model(sample.unsqueeze(0)).detach().numpy().reshape(32,32,3))

        plt.show()

def test():
    
    data_dir = 'dataset'

    batch_size = 256
    
    model = torch.load("vae_CIFAR.model")
    model.eval()

    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()

    test_loader = DataLoader(test_dataset,batch_size = 256)

    test_loss = []

    # create device 
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        for sample, _ in tqdm(test_loader):

            # move sample to appropriate device
            sample = sample.to(device)
            prediction = model(sample)

            loss = model.loss(sample, prediction)

            test_loss.append(loss.detach().item())

    plt.figure()
    plt.title("Testing Loss")
    plt.plot(test_loss)

    plt.show()

    return

def debug():
    
    data_dir = 'dataset'

    batch_size = 256
    
    model = torch.load("vae_CIFAR.model").encoder
    model.eval()

    test_dataset = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)
    test_dataset.transform = torchvision.transforms.ToTensor()

    sample,_ = test_dataset[0]

    sample = sample.unsqueeze(0)

    layer = model[0] # fetch the first conv layer

    demo_layer = nn.Conv2d(in_channels=3,out_channels=8,kernel_size=(3,3),stride=2)

    predict = layer(sample)

    predict_demo = demo_layer(sample)

    print(f"The model output: \n {predict}")
    print(f"The singleton conv layer output: \n {predict_demo}")

if __name__ == "__main__":
    # create device to train on GPU (if available)
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # demo()
    # test()

    # create model and move to device
    model = VariationalAutoEncoder().to(device)
    model.train_model(plot=True)

    torch.save(model, "vae_CIFAR.model")
    
    debug()

I would appreciate any and all input. Thank you in advance

1 Like