Train loss is good but validation loss is large and not decreasing much. My model looks as follows
class DownSample_Block(nn.Module):
def __init__(self,in_ch,out_ch,down_sample=True,batch_norm=True):
super().__init__()
self.down_sample = down_sample
self.batch_norm = batch_norm
self.pool = nn.MaxPool2d(2)
self.bn = nn.BatchNorm2d(num_features = out_ch)
self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1)
def forward(self,x):
if self.down_sample:
x = self.pool(x)
# conv ->bn-> relu
x = self.conv1(x)
if self.batch_norm:
x = self.bn(x)
x = F.relu(x)
x = self.conv2(x)
if self.batch_norm:
x = self.bn(x)
x = F.relu(x)
return x
class Upsample_Block(nn.Module):
def __init__(self,in_ch,out_ch,skip_ch):
super().__init__()
self.upsample = nn.ConvTranspose2d(in_ch,in_ch,kernel_size=2,stride=2)
# self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.bn = nn.BatchNorm2d(num_features=out_ch)
self.conv1 = nn.Conv2d(in_channels=in_ch+skip_ch, out_channels=out_ch, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1)
def forward(self,x,skip):
up = self.upsample(x)
out = torch.cat((skip,up),axis=1)
out = F.relu(self.bn(self.conv1(out)))
out = F.relu(self.bn(self.conv2(out)))
return out
class Unet(nn.Module):
def __init__(self):
super().__init__()
bins = 15
self.eblock1 = DownSample_Block(3,32,down_sample=False)
self.eblock2 = DownSample_Block(32,64)
self.eblock3 = DownSample_Block(64,128)
self.eblock4 = DownSample_Block(128,256)
self.eblock5 = DownSample_Block(256,512)
self.dblock2 = Upsample_Block(512,256,256)
self.dblock3 = Upsample_Block(256,128,128)
self.dblock4 = Upsample_Block(128,64,64)
self.dblock5 = Upsample_Block(64,32,32)
# distance
self.dist_conv = nn.Conv2d(32,bins,kernel_size=1)
#segmentation
self.seg_conv = nn.Conv2d(32+bins,2,kernel_size=1)
def forward(self,x):
d1 = self.eblock1(x)
# print(d1.shape)
d2 = self.eblock2(d1)
# print(d2.shape)
d3 = self.eblock3(d2)
# print(d3.shape)
d4 = self.eblock4(d3)
# print(d4.shape)
d5 = self.eblock5(d4)
# print(d5.shape)
u1 = self.dblock2(d5,d4)
# print(u1.shape)
u2 = self.dblock3(u1,d3) #d5
# print(u2.shape)
u3 = self.dblock4(u2,d2)
# print(u3.shape)
u4 = self.dblock5(u3,d1)
# print(u4.shape)
#distance out
out_dist = self.dist_conv(u4)
#segment out # relu(out_dist) + u4 out
out_relu = F.relu(out_dist)
out_concat = torch.cat([u4,out_relu],axis=1)
out_seg = self.seg_conv(out_concat)
# print(out_dist.shape,out_seg.shape)
return out_seg,out_dist
My loss function and distance transform function is
def distance_map_batch_v2(Y_batch,threshold=20,bins=15):
'''
Computes distance map of https://arxiv.org/abs/1709.05932 .
Threshold and number of bins are parameters explained in the paper.
Y_batch is the one hot encoded mask of the map with buildings and background. Mask for background is stored on channel 0 and Mask for buildings is stored on channel 1. Y_batch is a torch tensor
returns: torch tensor distance map
'''
Y_batch_dist = []
for i in range(len(Y_batch)):
msk = Y_batch[i].detach().cpu().numpy()
msk = msk.transpose(1,2,0)
distance_build=distance_transform_bf(msk[:,:,1],sampling=2)
distance_background=distance_transform_bf(msk[:,:,0],sampling=2)
distance_build=np.minimum(distance_build,threshold*(distance_build>0))
distance_background=np.minimum(distance_background,threshold*(distance_background>0))
distance=(distance_build-distance_background)
distance=(distance-np.amin(distance))/(np.amax(distance)-np.amin(distance)+1e-50)*(bins-1)
inp=torch.LongTensor(distance)
inp_ = torch.unsqueeze(inp, len(distance.shape))
one_hot = torch.FloatTensor(distance.shape[0], distance.shape[1], bins).zero_()
one_hot.scatter_(len(distance.shape), inp_, 1)
one_hot=np.asarray(one_hot)
one_hot = one_hot.transpose(2,0,1)
Y_batch_dist.append(one_hot)
return torch.FloatTensor(np.asarray(Y_batch_dist))
def cce(y_true,y_est):
'''
Loss used --> cross entropy
'''
y_true = y_true.contiguous().view(-1,y_true.size()[-1])
y_est = y_est.contiguous().view(-1,y_true.size()[-1])
y_true_flat = y_true.max(-1)[1]
# print(y_est.shape,y_true_flat.shape)
loss_func = nn.CrossEntropyLoss()
loss = loss_func(y_est,y_true_flat)#be careful inverse order of arguments
return loss
class MulticlassLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(MulticlassLoss, self).__init__()
def forward(self, pred_mask,pred_dist, y_batch, y_dist, smooth=1):
#comment out if your model contains a sigmoid or equivalent activation layer
bce_weight = 0.5
loss_seg = cce(y_batch,pred_mask)
# print(loss_seg)
loss_dist = cce(y_dist,pred_dist)
# print(loss_dist)
loss = (bce_weight*loss_seg)+(bce_weight*loss_dist)
return loss
Training code is
checkpoint_path = 'model2/chkpoint_'
best_model_path = 'model2/bestmodel.pt'
# model = model.to(device)
epochs = 50
criterion = MulticlassLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
dice_coef = DiceCoef()
valid_loss_min = 3.95275 # some random value
train_loss,val_loss = [],[]
train_iou,val_iou = [],[]
train_coef,val_coef = [],[]
for epoch in range(epochs):
print('Epoch {}/{}'.format(epoch + 1, epochs))
start_time = time.time()
model.train()
running_train_loss = []
running_train_score = []
running_train_coef = []
for image,mask in train_loader:
image = image.to(device,dtype=torch.float)
mask = mask.to(device,dtype=torch.float)
true_dist = distance_map_batch_v2(mask)
true_dist = true_dist.to(device,dtype=torch.float)
pred_mask,pred_dist = model.forward(image) # forward propogation
loss = criterion(pred_mask, pred_dist, mask, true_dist)
# print(pred_mask.shape,mask.shape)
score = iou_batch(pred_mask,mask)
coef = dice_coef(pred_mask,mask)
optimizer.zero_grad() # setting gradient to zero
loss.backward()
optimizer.step()
running_train_loss.append(loss.item())
running_train_score.append(score)
running_train_coef.append(coef.item())
running_val_loss = []
running_val_score = []
running_val_coef = []
model.eval()
with torch.no_grad():
for image,mask in val_loader:
image = image.to(device,dtype=torch.float)
mask = mask.to(device,dtype=torch.float)
true_dist = distance_map_batch_v2(mask)
true_dist = true_dist.to(device,dtype=torch.float)
pred_mask,pred_dist = model.forward(image)
loss = criterion(pred_mask, pred_dist, mask, true_dist)
score = iou_batch(pred_mask,mask)
coef = dice_coef(pred_mask,mask)
running_val_loss.append(loss.item())
running_val_score.append(score)
running_val_coef.append(coef.item())
epoch_train_loss,epoch_train_score,epoch_train_coef = np.mean(running_train_loss) ,np.mean(running_train_score),np.mean(running_train_coef)
print('Train loss : {} iou : {}, dice coef: {}'.format(epoch_train_loss,epoch_train_score,epoch_train_coef))
train_loss.append(epoch_train_loss)
train_iou.append(epoch_train_score)
train_coef.append(epoch_train_coef)
epoch_val_loss,epoch_val_score,epoch_val_coef = np.mean(running_val_loss),np.mean(running_val_score),np.mean(running_val_coef)
print('Validation loss : {} iou : {}, dice coef: {}'.format(epoch_val_loss,epoch_val_score,epoch_val_coef))
val_loss.append(epoch_val_loss)
val_iou.append(epoch_val_score)
val_coef.append(epoch_val_coef)
# create checkpoint variable and add important data
checkpoint = {
'epoch': epoch + 1,
'valid_loss_min': epoch_val_loss,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
# save checkpoint
save_ckp(checkpoint, False, checkpoint_path, best_model_path)
## TODO: save the model if validation loss has decreased
if epoch_val_loss <= valid_loss_min:
print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(valid_loss_min,epoch_val_loss))
# save checkpoint as best model
save_ckp(checkpoint, True, checkpoint_path, best_model_path)
valid_loss_min = epoch_val_loss
time_elapsed = time.time() - start_time
print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Results is
Epoch 1/50
Train loss : 1.636072209051677 iou : 0.2590805115404147, dice coef: 0.4123131136809077
Validation loss : 3.644066361586253 iou : 0.3961057460856792, dice coef: 0.5682796835899353
Validation loss decreased (3.952750 --> 3.644066). Saving model …
32m 49s
Epoch 2/50
Train loss : 1.2269017398357391 iou : 0.24557374558415218, dice coef: 0.39523599807705195
Validation loss : 3.8924281001091003 iou : 0.4135372376490164, dice coef: 0.5854614973068237
14m 48s
Epoch 3/50
Train loss : 1.1182671712977545 iou : 0.23921578082771644, dice coef: 0.3875904355730329
Validation loss : 3.321150771776835 iou : 0.39828164579205444, dice coef: 0.5704469372828801
Validation loss decreased (3.644066 --> 3.321151). Saving model …
14m 55s
Epoch 4/50
Train loss : 1.047594209228243 iou : 0.2370423747555327, dice coef: 0.3851519699607577
Validation loss : 3.4572733004887897 iou : 0.413889831623608, dice coef: 0.5856322149435679
14m 38s
Epoch 5/50
Train loss : 0.9927580939871924 iou : 0.23708825814195278, dice coef: 0.38523539411170143
Validation loss : 3.2639930764834086 iou : 0.3895362960387864, dice coef: 0.560992560784022
Validation loss decreased (3.321151 --> 3.263993). Saving model …
14m 32s
Epoch 6/50
Train loss : 0.9615498576845442 iou : 0.23610517876921325, dice coef: 0.38409997075796126
Validation loss : 3.03681640625 iou : 0.3960668315457316, dice coef: 0.5678311238686243
Validation loss decreased (3.263993 --> 3.036816). Saving model …
14m 28s
Epoch 7/50
Train loss : 0.9389585907970156 iou : 0.23596643507862994, dice coef: 0.3839401804975101
Validation loss : 2.9654738306999207 iou : 0.38705827275984567, dice coef: 0.5584435482819875
Validation loss decreased (3.036816 --> 2.965474). Saving model …
14m 22s
Epoch 8/50
Train loss : 0.9142322753156934 iou : 0.2337625844423078, dice coef: 0.38102387764624185
Validation loss : 2.775194569428762 iou : 0.37924780132712754, dice coef: 0.5504745692014694
Validation loss decreased (2.965474 --> 2.775195). Saving model …
14m 12s
Epoch 9/50
Train loss : 0.8896246692964009 iou : 0.22980133446432147, dice coef: 0.3760204364146505
Validation loss : 2.7394471764564514 iou : 0.3958687851408545, dice coef: 0.5675292551517487
Validation loss decreased (2.775195 --> 2.739447). Saving model …
14m 23s
Epoch 10/50
Train loss : 0.8742967473609107 iou : 0.22970163538591346, dice coef: 0.375763281754085
Validation loss : 2.7312774141629537 iou : 0.43684185002205045, dice coef: 0.6085633635520935
Validation loss decreased (2.739447 --> 2.731277). Saving model …
14m 32s
Epoch 11/50
Train loss : 0.8534701717751366 iou : 0.2294650052856951, dice coef: 0.375523764533656
Validation loss : 2.4273023923238117 iou : 0.3948336292033707, dice coef: 0.5670865674813589
Validation loss decreased (2.731277 --> 2.427302). Saving model …
14m 28s
Epoch 12/50
Train loss : 0.8340641353811536 iou : 0.22669237089401914, dice coef: 0.372151409302439
Validation loss : 2.761715757846832 iou : 0.3619836216192266, dice coef: 0.5321707213918369
14m 46s
Epoch 13/50
Train loss : 0.8210224441119602 iou : 0.2280403249532117, dice coef: 0.37345540864127025
Validation loss : 2.6830695708592733 iou : 0.4035514459610498, dice coef: 0.575622484087944
14m 59s
Epoch 14/50
Train loss : 0.8059026994875499 iou : 0.2248244423588186, dice coef: 0.3693577834538051
Validation loss : 2.1368691086769105 iou : 0.412322404646021, dice coef: 0.5850987017154694
Validation loss decreased (2.427302 --> 2.136869). Saving model …
15m 0s
Epoch 15/50
Train loss : 0.7877428701945713 iou : 0.22854895431555933, dice coef: 0.3743239802973611
Validation loss : 2.2300844828287762 iou : 0.38135859358075624, dice coef: 0.5536836306254069
15m 10s
Epoch 16/50
Train loss : 0.7829681926539966 iou : 0.22874669823879584, dice coef: 0.3749655529856682
Validation loss : 2.4254708528518676 iou : 0.41929319700766815, dice coef: 0.5917253951231639
15m 20s
Epoch 17/50
Train loss : 0.7644439024584634 iou : 0.23029824865898385, dice coef: 0.3765100053378514
Validation loss : 2.4135544896125793 iou : 0.4461470567967697, dice coef: 0.6174438158671062
15m 24s
Epoch 18/50
Train loss : 0.7528764850326947 iou : 0.22475351346286607, dice coef: 0.36964563727378846
Validation loss : 2.076960837841034 iou : 0.46038506970180504, dice coef: 0.6315884411334991
Validation loss decreased (2.136869 --> 2.076961). Saving model …
15m 25s
Epoch 19/50
Train loss : 0.7382002962487084 iou : 0.22603396174056087, dice coef: 0.3714781603642872
Validation loss : 2.2174523651599882 iou : 0.45246894388506703, dice coef: 0.623819363117218
14m 40s
Epoch 20/50
Train loss : 0.726068066912038 iou : 0.22376600579043568, dice coef: 0.3686292886734009
Validation loss : 1.9367382168769836 iou : 0.39958022276006816, dice coef: 0.5725013802448908
Validation loss decreased (2.076961 --> 1.936738). Saving model …
13m 45s
Epoch 21/50
Train loss : 0.714294318641935 iou : 0.22446068206685127, dice coef: 0.36995240322181155
Validation loss : 2.322433396180471 iou : 0.4313971246092888, dice coef: 0.6031914939483006
13m 46s
Epoch 22/50
Train loss : 0.6834765095795904 iou : 0.21772951261682544, dice coef: 0.3606676868030003
Validation loss : 2.0876523753007254 iou : 0.45842205361599536, dice coef: 0.6297118236621221
13m 40s
Epoch 23/50
Train loss : 0.6793745417680059 iou : 0.21839416961569164, dice coef: 0.36152691223791666
Validation loss : 2.267058895031611 iou : 0.47621269970422203, dice coef: 0.6458630830049514
13m 40s
…
Is my model overfitting. What might be reason here? Please someone help to solve this issue.