Issue with optimizer.step()

I am currently trying to train a model by converting TensorFlow code to PyTorch, and I am stuck on an issue that I cannot figure out.
When I comment out optimizer.step() during the training loop, I am able to print out each batch prediction tensor, but once I put optimizer.step() back in, the tensors fill up with NaN. I am using a custom loss function and a custom layer that I believe are coded correctly, and I am not sure where the issue is.

Here is my code if someone has the time to help diagnose!

Edit: Please note I am returning 3 tensors from my “main” model, Variational_autoencoder: z_mean, z_log_var, and reconstruction. I take these 3 tensors to calculate my loss via the loss function.

import torch
import torch.nn as nn
import torch.nn.functional as F


# defining a custom pytorch layer for sampling a random z from the latent distribution
class Sampling(nn.Module):
    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(z_mean, z_log_var, training=False):
        batch = z_mean.size(0)
        dim = z_mean.size(1)
        # during training, we want to sample from the distribution, so draw epsilon
        if training:
            # epsilon is a random sample from the standard normal distribution
            # with shape (batch, dim), matching z_mean
            epsilon = torch.normal(mean=0.0, std=1.0, size=(batch, dim)).to(z_mean.device)
        # during testing, we don't want variance to play a role
        else:
            # make epsilon 0 so that we drop the variance and keep just the mean
            epsilon = torch.zeros(batch, dim).to(z_mean.device)
        # sigma is exp(log(sigma^2) / 2) = exp(0.5 * z_log_var)
        # thus z_sample = z_mean + sigma * epsilon
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon
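
# quick sanity check (hypothetical usage): with training=False, epsilon is
# zero, so Sampling should return z_mean exactly
# sampler = Sampling()
# mu, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
# assert torch.equal(sampler(mu, logvar, training=False), mu)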


# compresses high-dimensional input data such as an image into a lower-dimensional embedding vector
# encodes to a latent space z
# embedding space and latent space are the same thing
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # first convolution
        self.conv1 = nn.Sequential(
            # taking a grayscale channel and converting it to 32 features
            # reducing the y*x dimension for each feature map by half
            nn.Conv2d(in_channels=1, out_channels=32,
                      kernel_size=3, stride=2,
                      padding=1),
            nn.ReLU()
        )
        # second convolution
        self.conv2 = nn.Sequential(
            # taking 32 feature maps and creating 64 feature maps out of it
            # reducing the y*x dimension for each feature map by half
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=3, stride=2,
                      padding=1),
            nn.ReLU()
        )
        # third convolution
        self.conv3 = nn.Sequential(
            # taking 64 feature maps and creating 128 feature maps out of it
            # reducing the y*x dimension for each feature map by half
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, stride=2,
                      padding=1),
            nn.ReLU()
        )
        # take the flattened output (2048 = 128 channels * 4 * 4 spatial, assuming 32x32 inputs)
        # and map it to 2 nodes each, representing the 2-dimensional latent space
        self.fc1 = nn.Linear(2048, 2)
        self.fc2 = nn.Linear(2048, 2)
        # the sampling layer takes z_mean from fc1 and z_log_var from fc2 and returns a sampled z
        self.sample = Sampling()

    def forward(self, input_data, training=False):
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        flatten = torch.flatten(x, start_dim=1)
        z_mean = self.fc1(flatten)
        z_log_var = self.fc2(flatten)
        z = self.sample(z_mean, z_log_var, training)
        # debug print to watch the sampled latents for NaNs
        print(z)
        return z, z_mean, z_log_var


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        # take 2 nodes and fully connect to 2048 nodes
        self.fc1 = nn.Linear(2, 2048)
        # apply a reshape from a vector of 2048 to (128,4,4)
        # first transposed convolution
        self.tp1 = nn.Sequential(
            # take 128 channels and keep 128 channels
            # here we get (128, 8, 8), thus doubling the feature-map width and height
            nn.ConvTranspose2d(
                in_channels=128, out_channels=128,
                kernel_size=3, stride=2,
                padding=1, output_padding=1
            ),
            nn.ReLU()
        )
        # second transposed convolution
        self.tp2 = nn.Sequential(
            # take 128 channels and convert to 64 channels
            # here we get (64, 16, 16), thus halving the channels and doubling the feature-map width and height
            nn.ConvTranspose2d(
                in_channels=128, out_channels=64,
                kernel_size=3, stride=2,
                padding=1, output_padding=1
            ),
            nn.ReLU()
        )
        # third transposed convolution
        self.tp3 = nn.Sequential(
            # take 64 channels and convert to 32 channels
            # here we get (32, 32, 32), thus halving the channels and doubling the feature-map width and height
            nn.ConvTranspose2d(
                in_channels=64, out_channels=32,
                kernel_size=3, stride=2,
                padding=1, output_padding=1
            ),
            nn.ReLU()
        )
        # apply convolution
        # take 32 channels and convert to 1 channel (meant to represent the gray channel)
        # convolution will give (1, 32, 32)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=32, out_channels=1,
                kernel_size=3, stride=1,
                padding=1
            ),
            # apply sigmoid activation to get a 0-1 pixel range
            nn.Sigmoid()
        )

    def forward(self, input_data):
        x = self.fc1(input_data)
        x = torch.reshape(x, (-1, 128, 4, 4))
        x = self.tp1(x)
        x = self.tp2(x)
        x = self.tp3(x)
        x = self.conv1(x)
        return x


# combining encoder and decoder to make autoencoder
class Variational_autoencoder(nn.Module):
    # initialize autoencoder with encoder and decoder objects
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_data, training=False):
        # first pass the image through the encoder to get its latent-space representation
        z, z_mean, z_log_var = self.encoder(input_data, training)
        # decode image from latent space back to pixel space
        reconstruction = self.decoder(z)
        return z_mean, z_log_var, reconstruction
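
# quick shape sanity check (hypothetical usage, assuming 32x32 grayscale
# inputs, which is what the 2048-unit flatten implies):
# vae = Variational_autoencoder(Encoder(), Decoder())
# z_mean, z_log_var, reconstruction = vae(torch.rand(4, 1, 32, 32), training=True)
# shapes: (4, 2), (4, 2), (4, 1, 32, 32)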


# custom loss function
def loss_fn(z_mean, z_log_var, reconstruction, input):
    reconstruction_loss = torch.mean(
        500 * F.binary_cross_entropy(input, reconstruction)
    )
    kl_loss = torch.mean(
        torch.sum(-0.5 * (1 + z_log_var - torch.square(z_mean) - torch.exp(z_log_var)), dim=1)
    )
    total_loss = reconstruction_loss + kl_loss
    return total_loss
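
# (the kl_loss term above is the closed-form KL divergence between
# N(z_mean, exp(z_log_var)) and the standard normal N(0, I):
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), summed over the latent
# dimensions and averaged over the batch)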


def train(model, epochs, trainloader, device):
    # lists to store per-epoch loss (and accuracy, unused for autoencoder training)
    loss_hist = [0] * epochs
    accuracy_hist = [0] * epochs
    for epoch in range(epochs):
        for x_batch, y_batch in trainloader:
            # put the batches on the device
            # only need x batch for autoencoder training
            x_batch = x_batch.to(device)
            # pass through encoder, then decode back and get result as prediction
            z_mean, z_log_var, reconstruction = model(x_batch, True)
            # loss is the BCE between the input pixels and the reconstructed pixels, plus the KL term
            loss = loss_fn(z_mean, z_log_var, reconstruction, x_batch)

            # back propagation
            loss.backward()
            # adam optimization to adjust weights
            optimizer.step()
            # zero out gradients for next batch
            optimizer.zero_grad()

            # accumulate the total loss for the epoch:
            # loss.item() is the average loss over the batch and
            # x_batch.size(0) is the number of samples in the batch,
            # so multiplying them recovers the summed loss for the batch
            loss_hist[epoch] += loss.item() * x_batch.size(0)
        # divide the accumulated loss by the dataset length to get the average loss per sample
        loss_hist[epoch] /= len(trainloader.dataset)
        print(f'Epoch {epoch} loss:{loss_hist[epoch]}')
        print('-' * 60)

Check the gradients of all trainable parameters, e.g. via:

for name, param in model.named_parameters():
    print(f"{name} isfinite: {param.isfinite().all()}, grad isfinite: {param.grad.isfinite().all()}")

and make sure no gradients are invalid, as it seems the optimizer.step() call might be updating the parameters to NaN or Inf values.

I assume I should put that after the training loop? My output is: AttributeError: 'NoneType' object has no attribute 'isfinite'
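
(For reference: param.grad is None for any parameter that has not received a gradient yet, and optimizer.zero_grad() can also reset gradients to None, which is the default in newer PyTorch versions, so checking after the loop can fail like this. A guarded version of the check, ideally run right after loss.backward() and before zero_grad(), might look like:)

for name, param in model.named_parameters():
    grad_ok = param.grad.isfinite().all() if param.grad is not None else "no grad yet"
    print(f"{name} isfinite: {param.isfinite().all()}, grad isfinite: {grad_ok}")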

This is my main function to run everything:

if __name__ == "__main__":
    # seeding for reproducibility across runs
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True
    # checking if cuda is available
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using {device}")
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    variational_autoencoder = Variational_autoencoder(encoder, decoder).to(device)
    learning_rate = 0.001
    # summary(variational_autoencoder, (1, 32, 32))
    optimizer = torch.optim.Adam(variational_autoencoder.parameters(), lr=learning_rate)
    batch_size = 100
    epochs = 1
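    # trainloader construction isn't shown above; the lines below are an
    # assumed example using Fashion-MNIST padded from 28x28 to 32x32,
    # matching the grayscale 32x32 input the encoder expects (they require
    # `import torchvision` and `from torchvision import transforms`)
    transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
    train_set = torchvision.datasets.FashionMNIST(root=".", train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)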
    train(variational_autoencoder, epochs, trainloader, device)
    for name, param in variational_autoencoder.named_parameters():
        print(f"{name} isfinite: {param.isfinite().all()}, grad isfinite: {param.grad.isfinite().all()}")

    a = 1

I just replaced my custom loss function with nn.BCELoss(), passing in the model's returned reconstruction tensor and comparing it against the x_batch tensor, and the NaNs are gone. It must be something wrong with my custom loss function.

Wow, I found the solution. It was a silly mistake that I swear took me 4 hours to diagnose lol. It was my loss function: TensorFlow's and PyTorch's binary cross entropy functions have their 1st and 2nd parameters switched!

PyTorch has F.binary_cross_entropy(predict, target) while TensorFlow has losses.binary_crossentropy(target, predict). (That likely also explains the NaNs: F.binary_cross_entropy takes the log of its first argument, so with the arguments swapped it was taking the log of input pixels that can be exactly 0.)
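
Side by side (the TensorFlow call is shown for comparison; it assumes the usual tf.keras.losses import):

# PyTorch: prediction first, target second
loss = F.binary_cross_entropy(reconstruction, x_batch)
# TensorFlow/Keras: target first, prediction second
loss = tf.keras.losses.binary_crossentropy(x_batch, reconstruction)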

I wish this book was written in PyTorch! Oh well, it's making me understand PyTorch better by forcing me to translate the code.

Putting this here to help anyone out there also making this mistake.
Love the PyTorch community!

Edit: added my fixed loss function here:

def loss_fn(z_mean, z_log_var, reconstruction, input):
    # prediction comes first in pytorch: F.binary_cross_entropy(prediction, target)
    # the factor of 500 weights the reconstruction loss against the KL term
    reconstruction_loss = torch.mean(
        500 * F.binary_cross_entropy(reconstruction, input)
    )
    kl_loss = torch.mean(
        torch.sum(-0.5 * (1 + z_log_var - torch.square(z_mean) - torch.exp(z_log_var)), dim=1)
    )
    total_loss = reconstruction_loss + kl_loss
    return total_loss