Unstable validation loss with constantly decreasing training loss

I am training an FCN-like model for semantic segmentation. The output of the model has shape [batch, 2, 224, 224], and the target has shape [batch, 224, 224]. I use nn.CrossEntropyLoss() as the loss function.
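For reference, a quick shape check of this setup, with random tensors standing in for real model outputs and masks. nn.CrossEntropyLoss accepts [batch, C, H, W] logits together with a [batch, H, W] target of int64 class indices and, with the default reduction='mean', averages the per-pixel loss over every pixel in the batch. The manual log_softmax/gather line is only there to show what is being averaged.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()            # default reduction='mean'
score = torch.randn(4, 2, 224, 224)          # [batch, classes, H, W] raw logits
mask = torch.randint(0, 2, (4, 224, 224))    # [batch, H, W] class indices (int64)

loss = criterion(score, mask)                # a single scalar

# Same value by hand: negative log-probability of the true class at each pixel,
# averaged over all batch * H * W pixels.
manual = -score.log_softmax(dim=1).gather(1, mask.unsqueeze(1)).mean()
print(loss.shape)  # torch.Size([])
```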

For training, I randomly split my dataset into training and validation sets with a [0.8, 0.2] ratio.
The training and validation code is as follows:

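import gc
import torch

def train(epoch, model, optimizer, criterion, train_loader, device):
    gc.collect()
    model.train()                      # enable dropout / batch-norm updates
    train_loss = 0.0
    for batch_idx, (img, mask) in enumerate(train_loader):
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        score = model(img)             # [batch, 2, 224, 224] logits
        loss = criterion(score, mask)  # mask: [batch, 224, 224], dtype long
        loss.backward()
        optimizer.step()
        # .item() returns a Python float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the epoch ends.
        train_loss += loss.item()
    return train_loss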

def validate(epoch, model, criterion, valid_loader, device):
    model.eval()                       # disable dropout, use running BN statistics
    epoch_val_loss = 0.0
    with torch.no_grad():              # no autograd graph needed for evaluation
        for idx, (img, mask) in enumerate(valid_loader):
            img, mask = img.to(device), mask.to(device)
            score = model(img)
            epoch_val_loss += criterion(score, mask).item()
    return epoch_val_loss

And the per-epoch loop is:

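for epoch in range(1, max_epoch + 1):
    train_loss = train(epoch, model, optimizer, criterion, train_loader, device)
    # Since PyTorch 1.1, scheduler.step() should come after optimizer.step(),
    # so step the learning-rate schedule once the epoch's training pass is done.
    scheduler.step()
    val_loss = validate(epoch, model, criterion, valid_loader, device)
    # '\\' prints a literal backslash ('\{' is an invalid escape sequence).
    print('Epoch:{}\\{}\tTrain Loss:{:.4f}\tValid Loss: {:.4f}'.format(
        epoch, max_epoch, train_loss / len(train_loader), val_loss / len(valid_loader)))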

During training, however, the training loss decreases steadily while the validation loss is unstable and, after epoch 2, never improves:

Epoch:1\100     Train Loss:0.7167       Valid Loss: 0.6851
Epoch:2\100     Train Loss:0.6815       Valid Loss: 0.6733
Epoch:3\100     Train Loss:0.6688       Valid Loss: 0.6925
Epoch:4\100     Train Loss:0.6570       Valid Loss: 0.6920
Epoch:5\100     Train Loss:0.6424       Valid Loss: 0.6932
Epoch:6\100     Train Loss:0.6225       Valid Loss: 0.7024
Epoch:7\100     Train Loss:0.5986       Valid Loss: 0.7194
Epoch:8\100     Train Loss:0.5603       Valid Loss: 0.8340
Epoch:9\100     Train Loss:0.5142       Valid Loss: 1.0435
Epoch:10\100    Train Loss:0.4744       Valid Loss: 0.9237
Epoch:11\100    Train Loss:0.4260       Valid Loss: 1.1318
Epoch:12\100    Train Loss:0.3650       Valid Loss: 1.1251
Epoch:13\100    Train Loss:0.3258       Valid Loss: 0.9745
Epoch:14\100    Train Loss:0.3010       Valid Loss: 0.9804
Epoch:15\100    Train Loss:0.2743       Valid Loss: 0.9908
Epoch:16\100    Train Loss:0.2625       Valid Loss: 0.9386
Epoch:17\100    Train Loss:0.2586       Valid Loss: 0.9806
Epoch:18\100    Train Loss:0.2452       Valid Loss: 1.6770
Epoch:19\100    Train Loss:0.2447       Valid Loss: 1.3434
Epoch:20\100    Train Loss:0.2333       Valid Loss: 1.3110
Epoch:21\100    Train Loss:0.2161       Valid Loss: 0.9107
Epoch:22\100    Train Loss:0.1960       Valid Loss: 0.9699
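Reading the logged numbers directly makes the pattern explicit. This snippet only re-types the values from the log above (no new data): the validation loss reaches its minimum at epoch 2, while the training loss falls every single epoch, the usual signature of overfitting.

```python
# Per-epoch losses copied from the log (epochs 1 to 22).
train_hist = [0.7167, 0.6815, 0.6688, 0.6570, 0.6424, 0.6225, 0.5986, 0.5603,
              0.5142, 0.4744, 0.4260, 0.3650, 0.3258, 0.3010, 0.2743, 0.2625,
              0.2586, 0.2452, 0.2447, 0.2333, 0.2161, 0.1960]
val_hist = [0.6851, 0.6733, 0.6925, 0.6920, 0.6932, 0.7024, 0.7194, 0.8340,
            1.0435, 0.9237, 1.1318, 1.1251, 0.9745, 0.9804, 0.9908, 0.9386,
            0.9806, 1.6770, 1.3434, 1.3110, 0.9107, 0.9699]

# Epoch (1-based) with the lowest validation loss.
best_epoch = min(range(len(val_hist)), key=val_hist.__getitem__) + 1
best_val = val_hist[best_epoch - 1]
# True if the training loss strictly decreases from each epoch to the next.
train_always_falls = all(b < a for a, b in zip(train_hist, train_hist[1:]))
final_gap = val_hist[-1] - train_hist[-1]

print(f"lowest validation loss {best_val:.4f} at epoch {best_epoch}")  # 0.6733 at epoch 2
print(f"training loss falls every epoch: {train_always_falls}")        # True
print(f"final gap (valid - train): {final_gap:.4f}")                   # 0.7739
```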

Can anyone help with this issue?


How large is your dataset?
You might add some regularization to your model (e.g. Dropout or weight decay) or add some data augmentation.
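A hedged sketch of both suggestions in PyTorch: nn.Dropout2d inside a hypothetical segmentation head, and weight decay passed to the optimizer. The layer sizes, dropout rate and weight_decay value are assumptions to be tuned. Dropout is only active in train() mode; the validate() function in the question already calls model.eval(), so validation outputs stay deterministic.

```python
import torch
import torch.nn as nn

# Hypothetical two-class segmentation head with channel-wise dropout.
# nn.Dropout2d zeroes whole feature maps, which suits convolutional features.
head = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Dropout2d(p=0.3),             # assumed rate; tune on the validation set
    nn.Conv2d(64, 2, kernel_size=1),
)

# Weight decay adds an L2 penalty through the optimizer; 1e-4 is an assumed value.
optimizer = torch.optim.SGD(head.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)

x = torch.randn(4, 64, 56, 56)
head.train()                         # dropout active during training
out_train = head(x)
head.eval()                          # dropout becomes a no-op in eval mode
with torch.no_grad():
    out_a, out_b = head(x), head(x)
print(out_train.shape)  # torch.Size([4, 2, 56, 56])
```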
Also, do you split the data randomly, or are you using some other logic?

Thank you for the reply.

I have 400 image/target pairs of size [224, 224].
The encoder of my network is initialized from a pretrained VGG16, while the decoder consists of 5 transposed convolution layers that upsample back to the input resolution (I suppose these ConvTranspose2d layers are randomly initialized). Is my dataset large enough to train this model?
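A sketch of such a decoder, assuming it takes the full VGG16 features output (512×7×7 for a 224×224 input) and that each transposed convolution doubles the spatial size (kernel 4, stride 2, padding 1): 7 → 14 → 28 → 56 → 112 → 224. The channel widths and the make_decoder helper are hypothetical, since the post does not give them. Newly constructed ConvTranspose2d layers do get PyTorch's default random initialization, and counting parameters shows how much capacity the decoder alone adds for roughly 320 training images.

```python
import torch
import torch.nn as nn

def make_decoder(widths=(512, 256, 128, 64, 32), num_classes=2) -> nn.Sequential:
    """Hypothetical decoder: 5 ConvTranspose2d layers, each doubling H and W."""
    layers = []
    # Four upsampling layers between consecutive widths.
    for in_ch, out_ch in zip(widths[:-1], widths[1:]):
        layers += [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                   nn.ReLU(inplace=True)]
    # Fifth transposed conv maps to per-class logits at full resolution.
    layers.append(nn.ConvTranspose2d(widths[-1], num_classes, kernel_size=4, stride=2, padding=1))
    return nn.Sequential(*layers)

decoder = make_decoder()
encoder_out = torch.randn(1, 512, 7, 7)   # stand-in for VGG16 features on a 224x224 input
out = decoder(encoder_out)
print(out.shape)  # torch.Size([1, 2, 224, 224])

# Capacity comparison: default widths vs. a slimmer (also hypothetical) decoder.
n_params = sum(p.numel() for p in decoder.parameters())
n_params_small = sum(p.numel() for p in make_decoder((512, 128, 64, 32, 16)).parameters())
print(n_params, n_params_small)  # 2786786 1221362
```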

For data splitting, I use the random_split function from PyTorch.
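A minimal sketch of a reproducible 80/20 split with torch.utils.data.random_split for 400 pairs. The TensorDataset of random tensors (at a reduced 32×32 resolution) is a hypothetical stand-in for the real image/mask dataset, and the seed value is arbitrary. Passing a seeded torch.Generator keeps the same validation images across runs, so validation curves from different runs remain comparable.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the real dataset: 400 image/mask pairs.
images = torch.randn(400, 3, 32, 32)
masks = torch.randint(0, 2, (400, 32, 32))
dataset = TensorDataset(images, masks)

n_train = int(0.8 * len(dataset))        # 320
n_val = len(dataset) - n_train           # 80
gen = torch.Generator().manual_seed(42)  # fixed seed -> same split every run
train_set, valid_set = random_split(dataset, [n_train, n_val], generator=gen)
print(len(train_set), len(valid_set))  # 320 80
```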

Your dataset is indeed quite small. I would add a lot of data augmentation and, if possible, reduce the capacity of your model (e.g. smaller layers in the decoder).
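For segmentation, geometric augmentations must be applied identically to the image and its mask, otherwise the labels no longer line up with the pixels. A minimal sketch with plain tensor ops, assuming square [C, H, W] float images and [H, W] integer masks; the function name joint_augment and the chosen transforms (horizontal flip, 90° rotation, brightness scaling) are illustrative, not taken from the thread. Photometric changes touch only the image.

```python
import torch

def joint_augment(img: torch.Tensor, mask: torch.Tensor):
    """img: [C, H, W] float, mask: [H, W] integer class map, with H == W assumed."""
    # Geometric transforms: draw each random parameter once, apply it to both.
    if torch.rand(1).item() < 0.5:
        img = torch.flip(img, dims=[2])            # horizontal flip (width axis)
        mask = torch.flip(mask, dims=[1])
    k = int(torch.randint(0, 4, (1,)).item())      # rotate by k * 90 degrees
    img = torch.rot90(img, k, dims=[1, 2])
    mask = torch.rot90(mask, k, dims=[0, 1])
    # Photometric jitter on the image only; the labels do not change.
    scale = 0.8 + 0.4 * torch.rand(1).item()       # brightness factor in [0.8, 1.2)
    return img * scale, mask
```

In practice this would typically be called only in the training dataset's __getitem__, leaving the validation set unaugmented so validation loss stays comparable across epochs.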


Thank you very much for the advice. I will try out those options.

After I reduced the encoder to only the first 7 conv layers of VGG16 and added data augmentation, the training and validation losses now decrease together. Thanks a lot for the help.
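One way to keep only the first 7 convolution layers, sketched under assumptions: the helper truncate_after_nth_conv is hypothetical, and a randomly initialized stack with VGG16's layer layout (config "D": conv, ReLU, ..., max-pool) stands in for torchvision.models.vgg16(...).features, which uses the same ordering, so the example runs without downloading weights. The post does not say whether the following max-pool was kept; here the cut is right after the ReLU of the 7th conv (conv3_3), which gives 256 channels at 1/4 resolution, so a matching decoder would need only two 2× upsampling steps.

```python
import torch
import torch.nn as nn

# VGG16 "features" layout: numbers are conv output channels, "M" is 2x2 max-pooling.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]

def build_vgg16_like_features(cfg=VGG16_CFG) -> nn.Sequential:
    """Randomly initialized stand-in with the same layer order as VGG16 features."""
    layers, in_ch = [], 3
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers += [nn.Conv2d(in_ch, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            in_ch = v
    return nn.Sequential(*layers)

def truncate_after_nth_conv(features: nn.Sequential, n: int) -> nn.Sequential:
    """Hypothetical helper: keep layers up to the ReLU that follows the n-th Conv2d."""
    seen = 0
    for i, layer in enumerate(features):
        if isinstance(layer, nn.Conv2d):
            seen += 1
            if seen == n:
                return features[: i + 2]           # slicing returns an nn.Sequential
    raise ValueError(f"only {seen} conv layers found")

encoder = truncate_after_nth_conv(build_vgg16_like_features(), 7)
with torch.no_grad():
    out = encoder(torch.randn(1, 3, 224, 224))
print(len(encoder))  # 16
print(out.shape)     # torch.Size([1, 256, 56, 56])
```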
