Bad segmentation masks when using model.eval()

Hi,

I am working on a hobby computer vision project where I use transfer learning to fine-tune the pretrained DeepLabV3 image segmentation network from torchvision to detect sidewalks.

Unfortunately, something seems to go wrong when I predict the output for arbitrary images: the segmentation masks, which look okay during training, become really bad and strange at inference time.

Here is the code that loads the saved weights from training into the model for evaluation.

import torch
import torchvision

def createDeepLabv3(outputchannels=1):
    # Load DeepLabV3 with a ResNet-101 backbone and pretrained weights
    model = torchvision.models.segmentation.deeplabv3_resnet101(
        weights=torchvision.models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT)
    # Replace the classifier head with one that has a single output channel
    model.classifier = torchvision.models.segmentation.deeplabv3.DeepLabHead(2048, outputchannels)
    # Freeze the backbone so only the new head is trained
    for param in model.backbone.parameters():
        param.requires_grad = False
    return model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = createDeepLabv3()
# map_location makes sure the checkpoint loads even if it was saved on a different device
model.load_state_dict(torch.load('/path_to_model/weights.pt', map_location=device))
model.to(device)
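
In case it matters, I also double-check after loading that only the classifier head is trainable (just a quick sanity check, not part of the training script):

# With the backbone frozen, only classifier parameters should require grad
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(len(trainable), trainable[:3])
assert all(name.startswith('classifier') for name in trainable)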

Here is the code that visualises a prediction for a chosen image.

import matplotlib.pyplot as plt
import cv2
import numpy as np

# cv2 loads images as BGR, H x W x C; rearrange into a 1 x 3 x 480 x 640 batch
img = cv2.imread('/path_to_img/img.png').transpose(2, 0, 1).reshape(1, 3, 480, 640)
mask = cv2.imread('/path_to_mask/mask.png')
model.eval()
with torch.no_grad():
    # Scale to [0, 1] before feeding the network
    a = model(torch.from_numpy(img).type(torch.FloatTensor).to(device) / 255)

# Raw logits of the single output channel
b = a['out'][0].cpu().numpy().squeeze()
plt.figure(figsize=(10, 10))
plt.subplot(131)
# Reverse the channel order (BGR -> RGB) so matplotlib shows the right colours
plt.imshow(img[0].transpose(1, 2, 0)[..., ::-1])
plt.title('Image')
plt.axis('off')
plt.subplot(132)
plt.imshow(mask)
plt.title('Ground Truth')
plt.axis('off')
plt.subplot(133)
# Binarise by thresholding the logits at their mean
plt.imshow(b > np.mean(b))
plt.title('Segmentation Output')
plt.axis('off')
print(np.mean(b))
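
One thing I am unsure about: since the head has a single output channel, maybe the logits should go through a sigmoid and be thresholded at 0.5 instead of at their mean. Something like this (same b as above):

prob = 1 / (1 + np.exp(-b))  # numpy sigmoid of the logits
plt.imshow(prob > 0.5)
plt.title('Sigmoid > 0.5')
plt.axis('off')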

Here are the results when using model.train() vs. model.eval():
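
For reference, this is roughly how I produce the two predictions I am comparing (a minimal sketch; img, model, and device are as defined above):

x = torch.from_numpy(img).type(torch.FloatTensor).to(device) / 255
with torch.no_grad():
    model.train()   # forward pass uses per-batch statistics in the BatchNorm layers
    out_train = model(x)['out'][0].cpu().numpy().squeeze()
    # note: train mode updates the BatchNorm running stats even under no_grad
    model.eval()    # forward pass uses the accumulated running statistics instead
    out_eval = model(x)['out'][0].cpu().numpy().squeeze()
print(np.abs(out_train - out_eval).mean())  # a large gap means the two modes really do disagree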

I have been stuck on this problem for a long time now, so if anyone knows what might be causing it, I would be very thankful.