Loss is calculated as NaN when using amp.autocast

This is my train function:
# NOTE(review): this paste is incomplete and will not parse as shown --
# parts of several lines were lost when it was posted: the right-hand side of
# `train_losses =`, the arguments and closing parenthesis of `tqdm(`, the
# argument of `y.to(`, and the call that the progress-bar string near the end
# belongs to. Presumably the backward/optimizer-step lines were lost as well.
# Restore them from the original file before running.
def _train(self):
self.model.train() # train mode
train_losses = # accumulate the losses here

        batch_iter = tqdm(

        for i, (x, y) in batch_iter:
            input, target = x.to(self.device), y.to(
            )  # send to device (GPU or CPU)

            self.optimizer.zero_grad()  # zerograd the parameters

            # Inside autocast, CUDA ops such as convolutions run in float16.
            # To locate the first NaN/inf, check `out` as well as `loss`,
            # e.g. print(torch.isfinite(out).all()).
            with torch.cuda.amp.autocast():
                out = self.model(input)
                # If `out` is finite but `loss` is NaN, casting the logits to
                # float32 before the criterion (the commented-out line below)
                # is worth trying -- especially for a custom loss.
                # out = out.float()
                loss = self.criterion(out, target)  # calculate loss

            # print(dice_loss(target,out))
            # loss += self.criterion2(out,target)

            # torch.cuda.empty_cache()

            # NOTE(review): `loss` is a tensor still attached to the autograd
            # graph. If `loss_value` is kept (e.g. appended to `train_losses`
            # in a line not visible here), `loss.item()` avoids holding on to
            # every batch's graph. The `:.4f` formatting below works either
            # way for a scalar (0-dim) loss.
            # NOTE(review): the backward pass is not visible here. With
            # autocast, PyTorch's AMP recipe scales the loss with
            # torch.cuda.amp.GradScaler (scaler.scale(loss).backward(),
            # scaler.step(self.optimizer), scaler.update()) -- confirm it is
            # used.
            loss_value = loss

                f"Training: (loss {loss_value:.4f})"
            )  # update progressbar

        # plt.plot(plt_trainloss)

I am doing a binary segmentation task where my model returns logits. When I tried to use AMP, my loss became NaN. Can anyone help me figure out what's wrong here?

Try to narrow down where the first invalid value is created by checking the loss calculation as well as the intermediate outputs of your forward pass.

Hey, thanks for replying! I'm pretty new to PyTorch and deep learning in general.
The NaN loss occurs in the first epoch, on the very first batch.
I tried to print the loss here:
with torch.cuda.amp.autocast():
    out = self.model(input)
    # To find where the first non-finite value appears, it also helps to
    # check the model output itself, e.g. print(torch.isfinite(out).all()).
    loss = self.criterion(out, target)  # calculate loss
    print(loss)  # debugging: inspect the loss computed under autocast

I get this output:
tensor(nan, device='cuda:0', grad_fn=<...>)
The loss is already NaN when it is computed inside the autocast block, before any gradient update happens.

I don't know why the model gives NaN outputs on the first batch under autocast; it works well without autocast:

Is something wrong with the model? Should I modify anything here?
I'm implementing a U-Net here:
"""Full assembly of the parts to form the complete network.

Defines the U-Net building blocks (DoubleConv, Down, Up, OutConv) and the
Unet model that wires them together.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class Unet(nn.Module):
    """U-Net: a four-level encoder/decoder with skip connections.

    ``forward`` runs ``inc`` and four ``Down`` stages, then four ``Up``
    stages that each concatenate the matching encoder feature map, and
    finally ``outc``. No activation is applied to the result, so the model
    returns raw logits (pair it with a logits-based loss such as
    ``nn.BCEWithLogitsLoss`` for a binary mask).

    Args:
        n_channels: number of channels of the input images.
        n_classes: number of output logit channels.
        bilinear: if True, the ``Up`` stages upsample with ``nn.Upsample``;
            otherwise with ``nn.ConvTranspose2d``.

    NOTE(review): under ``torch.cuda.amp.autocast`` the convolutions run in
    float16, whose largest finite value is about 65504. If the inputs are not
    normalized (e.g. raw 0-255 pixel values -- not visible here), activations
    can overflow to inf and then turn into NaN; confirm the input scaling.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        # The constructor must be named __init__ (not ``init``): nn.Module's
        # own setup has to run before any submodule is assigned.
        super(Unet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # nn.Upsample keeps the channel count, so in bilinear mode the deeper
        # stages emit half as many channels; after concatenation with the
        # skip feature map the total then matches each Up's in_channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return logits of shape (N, n_classes, H, W) for input (N, n_channels, H, W)."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: each Up upsamples and fuses the encoder map of the same level.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2

    Two 3x3 convolutions with padding=1 (spatial size is preserved), each
    followed by BatchNorm2d and ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; defaults
            to ``out_channels`` when not given (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        # Must run before any submodule is assigned to self.
        super(DoubleConv, self).__init__()
        if not mid_channels:
            mid_channels = out_channels
        # bias=False: each conv is followed by BatchNorm2d, whose learnable
        # shift makes a conv bias redundant.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv

    ``nn.MaxPool2d(2)`` halves height and width (odd sizes are floored),
    then ``DoubleConv`` maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        # Must run before any submodule is assigned to self.
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv

    ``forward(x1, x2)`` upsamples the deeper feature map ``x1`` by 2,
    zero-pads it to the spatial size of the skip feature map ``x2``,
    concatenates ``[x2, x1]`` along the channel dimension and applies
    ``DoubleConv``.

    Args:
        in_channels: channel count after concatenation (what ``DoubleConv``
            receives). In the ``ConvTranspose2d`` branch this is also the
            channel count of ``x1``, which the transpose conv halves.
        out_channels: channels produced by this stage.
        bilinear: use ``nn.Upsample`` (True) or ``nn.ConvTranspose2d``
            (False) for the upsampling step.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        # Must run before any submodule is assigned to self.
        super(Up, self).__init__()

        if bilinear:
            # nn.Upsample keeps the channel count, so DoubleConv's first conv
            # (mid_channels = in_channels // 2) does the channel reduction.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # The transpose conv halves the channels while doubling H and W.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW: size()[2] is height, size()[3] is width.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Pad (left, right, top, bottom) so x1 matches x2 when the encoder
        # saw an odd size and max-pooling floored it.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """1x1 convolution mapping ``in_channels`` feature maps to ``out_channels`` outputs.

    No activation is applied, so the outputs are raw logits.
    """

    def __init__(self, in_channels, out_channels):
        # The constructor must be named __init__ (not ``init``) so that
        # nn.Module's setup runs before self.conv is assigned.
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)