Loss diverges while training U-Net

I am trying to generate portrait images using image segmentation. I am using a U-Net with the following architecture.

import torch
import torch.nn as nn
class Unet(nn.Module):
    '''U-Net Architecture'''
    def __init__(self,inp,out):
        super(Unet,self).__init__()
        self.c1=self.contracting_block(inp,64)
        self.c2=self.contracting_block(64,128)
        self.c3=self.contracting_block(128,256)
        self.c4=self.contracting_block(256,512)
        self.c5=self.contracting_block(512,1024)
        self.maxpool=nn.MaxPool2d(2)
        self.upsample=nn.Upsample(scale_factor=2,mode="bilinear",align_corners=True)
        self.c6=self.contracting_block(512+1024,512)
        self.c7=self.contracting_block(512+256,256)
        self.c8=self.contracting_block(256+128,128)
        self.c9=self.contracting_block(128+64,64)
        self.c10=nn.Conv2d(64,1,1)
        

    def contracting_block(self,inp,out,k=3):
        block=nn.Sequential(
            nn.Conv2d(inp,out,kernel_size=k,padding=k//2),
            nn.BatchNorm2d(out),
            nn.ReLU(inplace=True),
            nn.Conv2d(out,out,kernel_size=k,padding=k//2),
            nn.BatchNorm2d(out),
            nn.ReLU(inplace=True)
        )
        return block


    def forward(self,x): # shape comments below assume a 256x256 input
        conv1=self.c1(x) #256x256x64
        conv1=self.maxpool(conv1) #128x128x64
        conv2=self.c2(conv1) #128x128x128
        conv2=self.maxpool(conv2) #64x64x128
        conv3=self.c3(conv2) #64x64x256
        conv3=self.maxpool(conv3) #32x32x256
        conv4=self.c4(conv3) #32x32x512
        conv4=self.maxpool(conv4) #16x16x512
        conv5=self.c5(conv4) #16x16x1024
        conv5=self.maxpool(conv5) #8x8x1024
        x=self.upsample(conv5) ##16x16x1024
        #print(x.shape)
        x=torch.cat([x,conv4],dim=1) #16x16x1536
        x=self.c6(x) #16x16x512
        x=self.upsample(x) #32x32x512
        x=torch.cat([x,conv3],dim=1)
        x=self.c7(x) #32x32x256
        x=self.upsample(x) #64x64x256
        x=torch.cat([x,conv2],dim=1)
        x=self.c8(x) #64x64x128
        x=self.upsample(x) #128x128x128
        x=torch.cat([x,conv1],dim=1)
        x=self.c9(x) #128x128x64
        x=self.upsample(x)#256x256x64
        x=self.c10(x)
        return x


if __name__=="__main__":
    x=torch.ones(1,3,256,512)
    net=Unet(3,1)
    print(net(x).shape)

I am using a publicly available dataset with almost 1300 images.
My loss always diverges. I am using a learning rate of 1e-5 and BCEWithLogitsLoss as the loss function.
My training loop is as follows.

def training_loop(*args,**kwargs):
    """
    Main training Loop
    keyword parameters:
    epochs:number of epochs
    lr:learning_rate
    
    """
    global net,valid_loader,train_loader,device
    epochs=kwargs["epochs"]
    lr=kwargs["lr"]
    if not os.path.isdir("checkpoints"):
        os.mkdir("checkpoints")
    criterion=None
    if model=="unet":
        criterion=nn.BCEWithLogitsLoss()
    elif model=="cnn":
        criterion=nn.CrossEntropyLoss()
    opt=optim.Adam(net.parameters(),lr=lr,weight_decay=1e-8)
    xx=[]
    yy=[]
    for epoch_num in range(1,epochs+1):
        running_loss=0.0
        for i,samples in enumerate(train_loader):

            imgs,masks=samples[0],samples[1]
            imgs,masks=imgs.to(device),masks.to(device)
            opt.zero_grad()
            outputs=net(imgs)
            loss=criterion(outputs,masks)
            loss.backward()
            opt.step()
            if(model=="unet"):
                running_loss += torch.exp(loss).item()
            elif(model=="cnn"):
                running_loss+=loss.item()

            if(i%20==19):
                valid_loss=validation(valid_loader=valid_loader,criterion=criterion)
                writer.add_scalars("first",{'train_loss':torch.tensor(running_loss/20),
                                            'validation_loss':torch.tensor(valid_loss)},epoch_num*len(train_loader)+i)

                writer.close()
                print("Epoch [%3d] iteration [%4d] loss:[%.10f]"%(epoch_num,i,running_loss/20),end="")
                print(" validation_loss:[%.10f]"%(valid_loss))
                running_loss=0.0
        torch.save(net.state_dict(),"checkpoints/"+str(epoch_num)+".pth")
        

I am not able to find the problem. Please help.

The model and training code look alright, and your model is able to learn some random inputs:

if __name__=="__main__":
    x=torch.ones(1,3,256,512)
    net=Unet(3,1)
    print(net(x).shape)

    net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    data = torch.randn(10, 3, 256, 512, device='cuda')
    target = torch.randint(0, 2, (10, 1, 256, 512)).float().cuda()
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(100):
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print('epoch {}, loss {}'.format(epoch, loss.item()))

What issues are you seeing? Could you check your target range, in case you are seeing an increasing loss?
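
For example, a quick sanity check could look like this (a minimal sketch, assuming the train_loader from your code; BCEWithLogitsLoss expects targets in [0, 1]):

# Sketch: print the value range of the first few target batches
for i, (imgs, masks) in enumerate(train_loader):
    print("batch %d: mask min %.4f, max %.4f" % (i, masks.min().item(), masks.max().item()))
    if i == 2:
        break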

Sir, I normalised the images and masks to have mean 0.5 and std 0.5. I also rechecked and the min and max values of images and masks are:

image --> max and min: tensor(1.) tensor(-0.9294)
mask --> max and min: tensor(1.) tensor(-1.)

What should I do?

Try to scale down your use case by overfitting a small subset of your data, e.g. just 10 samples.
If this doesn’t work, play around with some hyperparameters.
In case your model is not able to overfit these samples, you would have to check the architecture again, or there might be a bug in the rest of the training routine which is not shown here or which I might have missed.
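
For example, an overfitting run could look roughly like this (a minimal sketch, assuming the net, data, device, optim, and nn objects from the posted training script):

from torch.utils.data import DataLoader, Subset

# Sketch: train on just 10 samples; the loss should approach zero
# if the architecture and training routine are correct.
small_loader = DataLoader(Subset(data, list(range(10))), batch_size=2, shuffle=True)
opt = optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
for epoch in range(200):
    for imgs, masks in small_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        opt.zero_grad()
        loss = criterion(net(imgs), masks)
        loss.backward()
        opt.step()
    print("epoch %d, loss %.6f" % (epoch, loss.item()))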

Any reason why you do this?

running_loss += torch.exp(loss).item()

Why aren’t you simply summing the losses?
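
The usual pattern accumulates the raw value; a sketch of how the corresponding lines in the posted loop would look:

# BCEWithLogitsLoss is already non-negative, so the raw value can be
# accumulated directly; note that exp() of a large loss can overflow
# float32 to inf even while the loss itself is still finite.
running_loss += loss.item()
if i % 20 == 19:
    print("average loss: %.6f" % (running_loss / 20))
    running_loss = 0.0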

Sir, I am using BCEWithLogitsLoss as my loss function. To bring the loss into the range [0, infinity) I take the exponent, and I print the average loss every few steps.

Sir, I tried training on a small set of data. The loss becomes “nan” after some iterations. I also tried some other values of batch size and learning rate, but still no improvement.

The entire training code is:

import os
import cv2
import sys
import math
import torch
import argparse
import torchvision
import numpy as np
sys.path.append('')
import torch.nn as nn
import torch.optim as optim
from modelArch.unet import Unet
from modelArch.cnn import Cnn
from torchsummary import summary
from dataLoader.dataLoader_unet import load
from dataLoader.dataloader_cnn import load_cnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

parser=argparse.ArgumentParser()
parser.add_argument("--epochs",default=50)
parser.add_argument("--batch_size",default=4)
parser.add_argument("--lr",default=0.0001)
parser.add_argument("--model",default="unet",help="unet/cnn")
args=parser.parse_args()

writer=SummaryWriter('runs/trial1')

net=None
valid_loader=None
train_loader=None
device=None
data=None
model=None

def weights_init(m):
    if isinstance(m,nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.zeros_(m.bias)

def init(*args,**kwargs):
    """
    Initiates the training process
    keyword parameters:
    train_percent:[0,1]
    resume:pass checkpoint number from where to resume training
    batch_size
    """
    #resume=kwargs["resume"]
    global net,valid_loader,train_loader,device,data
    resume=None
    train_percent=kwargs["train_percent"]
    batch_size=kwargs["batch_size"]
    width=kwargs["width"]
    height=kwargs["height"]
    #model=kwargs["model"]
    if(model=="unet"):
        net=Unet(3,1)
        data=load(width=width,height=height)

    if(resume is not None):
        net.load_state_dict(torch.load("checkpoints/"+str(resume)+".pth"))
        print("Resuming training from checkpoint "+str(resume))
    else:
        net.apply(weights_init)
        
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using ",device)
    #summary(net,input_size=(3,256,256))
    
    size=len(data)
    train_size=math.floor(train_percent*size)
    test_size=size-train_size
    print("Data Loaded")
    train,validation=torch.utils.data.random_split(data,[train_size,test_size])
    train_loader=DataLoader(train,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader=DataLoader(validation,batch_size=batch_size,shuffle=True,num_workers=4)
    #x=iter(dataLoader)
    #img,mask=x.next()
    #grid=torchvision.utils.make_grid(img)
    #writer.add_image('images',grid,0)
    #writer.add_graph(net,img)
    #writer.close()
    net.to(device)

def validation(**kwargs):

    """
    keyword args:
    valid_loader: validation data loader
    """
    global net,valid_loader,train_loader,device
    valid_loader=kwargs["valid_loader"]
    criterion=kwargs["criterion"]
    #model=kwargs["model"]
    p=0
    valid_loss=0.0
    with torch.no_grad():
        for no,data in enumerate(valid_loader):
            imgs,masks=data[0].to(device),data[1].to(device)
            outputs=net(imgs)
            v_loss=criterion(outputs,masks)
            if(model=='unet'):
                valid_loss+=torch.exp(v_loss).item()       
            p+=1
    
    return valid_loss/p


def training_loop(*args,**kwargs):
    """
    Main training Loop
    keyword parameters:
    epochs:number of epochs
    lr:learning_rate
    
    """
    global net,valid_loader,train_loader,device
    epochs=kwargs["epochs"]
    lr=kwargs["lr"]
    if not os.path.isdir("checkpoints"):
        os.mkdir("checkpoints")
    criterion=None
    if model=="unet":
        criterion=nn.BCEWithLogitsLoss()
    opt=optim.Adam(net.parameters(),lr=lr,weight_decay=1e-8)
    xx=[]
    yy=[]
    for epoch_num in range(1,epochs+1):
        running_loss=0.0
        for i,samples in enumerate(train_loader):

            imgs,masks=samples[0],samples[1]
            imgs,masks=imgs.to(device),masks.to(device)
            opt.zero_grad()
            outputs=net(imgs)
            loss=criterion(outputs,masks)
            loss.backward()
            opt.step()
            if(model=="unet"):
                running_loss += torch.exp(loss).item()

            if(i%20==19):
                valid_loss=validation(valid_loader=valid_loader,criterion=criterion)
                writer.add_scalars("first",{'train_loss':torch.tensor(running_loss/20),
                                            'validation_loss':torch.tensor(valid_loss)},epoch_num*len(train_loader)+i)

                writer.close()
                print("Epoch [%3d] iteration [%4d] loss:[%.10f]"%(epoch_num,i,running_loss/20),end="")
                print(" validation_loss:[%.10f]"%(valid_loss))
                running_loss=0.0
        torch.save(net.state_dict(),"checkpoints/"+str(epoch_num)+".pth")
        
    
if __name__=="__main__":
    #global model
    model=args.model
    init(batch_size=int(args.batch_size),train_percent=0.95,width=256,height=512)
    training_loop(epochs=int(args.epochs),lr=float(args.lr))


    

Is the loss still going up? Also, could you check your input for NaN values?

Sir, I am sorry for the mistake, the loss quickly becomes “inf”, not “nan”. I checked the input as well and didn’t find any problem with it.

self.maxpool=nn.MaxPool2d(2)

Just a suggestion: you could try creating a separate nn.MaxPool2d(2) instance for each downsampling step instead of reusing self.maxpool repeatedly. I don’t know whether it causes a problem that all gradients flow through a single node (self.maxpool) in the graph.
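
For reference, separate instances in __init__ would look like this (nn.MaxPool2d holds no parameters, so this changes only the module structure, not the math):

# Sketch: one pooling module per downsampling stage
self.pool1 = nn.MaxPool2d(2)
self.pool2 = nn.MaxPool2d(2)
self.pool3 = nn.MaxPool2d(2)
self.pool4 = nn.MaxPool2d(2)
self.pool5 = nn.MaxPool2d(2)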

I tried replacing self.maxpool with separate nn.MaxPool2d(2) instances, but still no success; the loss decreases for a while and then quickly becomes infinity. My dataLoader looks like this:

import os
import cv2
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms




class load(Dataset):
    def __init__(self,**kwargs):
        self.width=kwargs["width"]
        self.height=kwargs["height"]
        self.samples=[]
        self.path1="/home/satinder/Desktop/deepWay/DeepWay.v2/dataSet/Segmentation2/img/"
        self.path2="/home/satinder/Desktop/deepWay/DeepWay.v2/dataSet/Segmentation2/mask/"
        img_folder=os.listdir(self.path1)
        
        for i in tqdm(img_folder):
            num=i.split(".")[0]
            self.samples.append((i,num+".png"))
        self.color=transforms.ColorJitter(brightness = 1)
        #self.translate=transforms.RandomAffine(translate=(0.1,0.1))
        self.angle=transforms.RandomAffine(degrees=(60))
        self.flip=transforms.RandomHorizontalFlip(p=0.5)
        self.transforms_img=transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

        self.transforms_mask=transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.5,),(0.5,))])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self,idx):
        i,j=self.samples[idx]
        img=cv2.imread(self.path1+i,1)
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        #img=cv2.blur(img,(3,3))
        mask=cv2.imread(self.path2+j,1)
        mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
        mask=cv2.Canny(mask,100,150)
        mask=cv2.dilate(mask,None,iterations=5)
        img=cv2.resize(img,(self.height,self.width)) # note: cv2.resize takes dsize as (width, height)
        mask=cv2.resize(mask,(self.height,self.width))
        #print(mask.shape)
        seed=np.random.randint(2147483647)
        img=Image.fromarray(img)
        mask=Image.fromarray(mask)
        

        random.seed(seed)
        #img=self.color(img)
        random.seed(seed)
        #img=self.translate(img)
        random.seed(seed)
        #img=self.angle(img)
        random.seed(seed)
        #img=self.flip(img)
        random.seed(seed)
        img=self.transforms_img(img)
        
        random.seed(seed)
        #mask=self.translate(mask)
        random.seed(seed)
        #mask=self.angle(mask)
        random.seed(seed)
        #mask=self.flip(mask)
        random.seed(seed)
        mask=self.transforms_mask(mask)
        #print(img)
        return (img,mask)
    
    def plot(self,img):
        img=np.transpose(img.numpy(),(1,2,0))
        img=img*0.5+0.5
        img=cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
        cv2.imshow("ds",img)
        cv2.waitKey(0)


if(__name__=="__main__"):
    obj=load(width=256,height=256)
    res=obj.__getitem__(7)
    obj.plot(res[0])
    obj.plot(res[1])
    #cv2.imshow("img",res[0].cpu().detach().numpy())
    

Any suggestions?

Maybe you are using your labels in the wrong way… why are you using a Canny filter?

Sir, it is because I was trying to predict only the boundary of the mask. I have tried it without applying Canny as well, and the results are the same.

Any other suggestions?

Try to pass the mask with values in the range [0, 1].
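
For example, dropping the Normalize from the mask pipeline keeps ToTensor's [0, 1] output (a minimal sketch against the posted transforms_mask; the binarization line is an optional extra):

import torchvision.transforms as transforms

# Sketch: keep the mask in [0, 1] for BCEWithLogitsLoss.
# ToTensor already scales a uint8 image to [0, 1], so no Normalize here.
transforms_mask = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])
# Optionally make the target strictly binary (the dilated Canny mask is 0/255):
# mask = (mask > 0.5).float()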

Perhaps there is a problem in the data.


There was a problem with the dataset. I changed it, and now it works properly.
