I am training an FCN-like model for semantic segmentation. The model's output has shape [batch, 2, 224, 224] and the target has shape [batch, 224, 224]. I use `nn.CrossEntropyLoss()` as the loss function.
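For reference, here is a minimal shape check for `nn.CrossEntropyLoss` with this setup. The sizes come from the post, and the random tensors are placeholders. The logits go in unnormalized (no softmax), and the target holds integer class indices of dtype `torch.long`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sizes from the post: 2 classes, 224x224 outputs; batch of 4 is arbitrary.
batch, num_classes, h, w = 4, 2, 224, 224
criterion = nn.CrossEntropyLoss()

# Raw logits, shape [batch, C, H, W]; log-softmax over dim 1 happens inside the loss.
score = torch.randn(batch, num_classes, h, w)

# Class indices, shape [batch, H, W], dtype long, values in [0, C).
mask = torch.randint(0, num_classes, (batch, h, w), dtype=torch.long)

loss = criterion(score, mask)  # scalar: mean over every pixel in the batch

# The same loss per pixel, to show what the default 'mean' reduction averages.
per_pixel = nn.CrossEntropyLoss(reduction="none")(score, mask)
print(loss.shape, per_pixel.shape)
```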
For training, I randomly split the dataset into training and validation sets in a [0.8, 0.2] ratio.
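The post doesn't show how the split is made. Below is a sketch using `torch.utils.data.random_split` with a seeded generator, which keeps the same 80/20 partition across runs so validation numbers stay comparable between experiments. The `TensorDataset` of random tensors is a hypothetical stand-in for the real dataset. Recent PyTorch releases also accept fractions such as `[0.8, 0.2]` directly. Integer lengths are used here because they work in older versions too.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in dataset: 100 small (image, mask) pairs.
images = torch.randn(100, 3, 32, 32)
masks = torch.randint(0, 2, (100, 32, 32))
dataset = TensorDataset(images, masks)

# A fixed seed makes the split reproducible from run to run.
generator = torch.Generator().manual_seed(42)
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
train_set, valid_set = random_split(dataset, [n_train, n_val], generator=generator)
print(len(train_set), len(valid_set))
```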
The code for training and validation is as follows:
```python
import gc

import torch


def train(epoch, model, optimizer, criterion, train_loader, device):
    gc.collect()
    model.train()
    train_loss = 0.0
    for batch_idx, (img, mask) in enumerate(train_loader):
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        score = model(img)             # logits: [batch, 2, 224, 224]
        loss = criterion(score, mask)  # mask: [batch, 224, 224], dtype long
        loss.backward()
        optimizer.step()
        # .item() returns a Python float; summing the tensor itself would
        # keep every batch's autograd graph alive until the epoch ends.
        train_loss += loss.item()
    return train_loss


def validate(epoch, model, criterion, valid_loader, device):
    model.eval()
    epoch_val_loss = 0.0
    with torch.no_grad():
        for idx, (img, mask) in enumerate(valid_loader):
            img, mask = img.to(device), mask.to(device)
            score = model(img)
            epoch_val_loss += criterion(score, mask).item()
    return epoch_val_loss
```
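Dividing the summed batch losses by `len(valid_loader)` gives the mean of per-batch means, which over-weights a smaller final batch. Below is a sketch of a sample-weighted variant. It assumes the default `reduction='mean'` and that all images have the same size. The unused `epoch` argument is dropped, and the toy 1x1-conv model and data in the usage part are hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def validate(model, criterion, valid_loader, device):
    """Return the mean validation loss per sample (sketch).

    Each batch loss is a per-pixel mean, so multiplying by the batch size
    keeps a short final batch from counting as much as a full one.
    """
    model.eval()
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for img, mask in valid_loader:
            img, mask = img.to(device), mask.to(device)
            score = model(img)
            total_loss += criterion(score, mask).item() * img.size(0)
            total_samples += img.size(0)
    return total_loss / total_samples


# Usage: 10 samples with batch_size=4 gives batches of 4, 4 and 2.
torch.manual_seed(0)
model = nn.Conv2d(3, 2, kernel_size=1)
images = torch.randn(10, 3, 8, 8)
masks = torch.randint(0, 2, (10, 8, 8))
loader = DataLoader(TensorDataset(images, masks), batch_size=4)
criterion = nn.CrossEntropyLoss()

val_loss = validate(model, criterion, loader, "cpu")
with torch.no_grad():
    full_batch_loss = criterion(model(images), masks).item()
print(val_loss, full_batch_loss)  # equal up to float rounding
```

Because every sample has the same pixel count, the weighted result matches the loss computed over the whole validation set in one batch.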
And the epoch loop is:

```python
for epoch in range(1, max_epoch + 1):
    train_loss = train(epoch, model, optimizer, criterion, train_loader, device)
    val_loss = validate(epoch, model, criterion, valid_loader, device)
    # Since PyTorch 1.1, scheduler.step() should come after optimizer.step(),
    # i.e. once per epoch, after training.
    scheduler.step()
    print('Epoch:{}\\{}\tTrain Loss:{:.4f}\tValid Loss: {:.4f}'.format(
        epoch, max_epoch, train_loss / len(train_loader), val_loss / len(valid_loader)))
```
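The post doesn't say which scheduler is used. As a sanity check on the call order, here is a small sketch with a hypothetical `StepLR` (decay by 10x every 2 epochs) that records the learning rate each epoch starts with when `scheduler.step()` runs after that epoch's `optimizer.step()` calls.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 2, kernel_size=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Hypothetical schedule for illustration only.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs_used = []
for epoch in range(1, 5):
    lrs_used.append(optimizer.param_groups[0]["lr"])
    # Stand-in for train(...), which calls optimizer.step() once per batch.
    optimizer.step()
    # Advance the schedule once, after the epoch's parameter updates.
    scheduler.step()

print(lrs_used)  # roughly [0.1, 0.1, 0.01, 0.01]
```

With `scheduler.step()` at the top of the loop instead, the first decay already applies in epoch 2. Recent PyTorch versions also warn about that call order.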
During training, the training loss decreases steadily, but the validation loss is unstable and does not decrease. It is lowest at epoch 2 (0.6733), then climbs and fluctuates, peaking at 1.6770 in epoch 18:
```
Epoch:1\100 Train Loss:0.7167 Valid Loss: 0.6851
Epoch:2\100 Train Loss:0.6815 Valid Loss: 0.6733
Epoch:3\100 Train Loss:0.6688 Valid Loss: 0.6925
Epoch:4\100 Train Loss:0.6570 Valid Loss: 0.6920
Epoch:5\100 Train Loss:0.6424 Valid Loss: 0.6932
Epoch:6\100 Train Loss:0.6225 Valid Loss: 0.7024
Epoch:7\100 Train Loss:0.5986 Valid Loss: 0.7194
Epoch:8\100 Train Loss:0.5603 Valid Loss: 0.8340
Epoch:9\100 Train Loss:0.5142 Valid Loss: 1.0435
Epoch:10\100 Train Loss:0.4744 Valid Loss: 0.9237
Epoch:11\100 Train Loss:0.4260 Valid Loss: 1.1318
Epoch:12\100 Train Loss:0.3650 Valid Loss: 1.1251
Epoch:13\100 Train Loss:0.3258 Valid Loss: 0.9745
Epoch:14\100 Train Loss:0.3010 Valid Loss: 0.9804
Epoch:15\100 Train Loss:0.2743 Valid Loss: 0.9908
Epoch:16\100 Train Loss:0.2625 Valid Loss: 0.9386
Epoch:17\100 Train Loss:0.2586 Valid Loss: 0.9806
Epoch:18\100 Train Loss:0.2452 Valid Loss: 1.6770
Epoch:19\100 Train Loss:0.2447 Valid Loss: 1.3434
Epoch:20\100 Train Loss:0.2333 Valid Loss: 1.3110
Epoch:21\100 Train Loss:0.2161 Valid Loss: 0.9107
Epoch:22\100 Train Loss:0.1960 Valid Loss: 0.9699
```
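To put numbers on "not decreasing", here is a small sketch over the validation losses copied from the log above. It reports the epoch with the lowest validation loss and the epoch where a simple patience-based early-stopping rule would have halted. The helper name `best_epoch_with_patience` is hypothetical, and the patience of 5 is an arbitrary choice for illustration.

```python
# Validation losses for epochs 1-22, copied from the log above.
val_losses = [
    0.6851, 0.6733, 0.6925, 0.6920, 0.6932, 0.7024, 0.7194, 0.8340,
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
][:8] + [
    1.0435, 0.9237, 1.1318, 1.1251, 0.9745, 0.9804, 0.9908,
    0.9386, 0.9806, 1.6770, 1.3434, 1.3110, 0.9107, 0.9699,
]


def best_epoch_with_patience(losses, patience):
    """Return (best_epoch, stop_epoch), both 1-based.

    best_epoch has the lowest loss seen before stopping; stop_epoch is the
    first epoch at which `patience` epochs have passed without a new best,
    or None if that never happens.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, None


print(best_epoch_with_patience(val_losses, patience=5))  # (2, 7)
```

On this log, the best validation loss comes at epoch 2. With a patience of 5, training would stop at epoch 7, while the training loss is still falling.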
Can anyone help with this issue?