Identical outputs for different inputs when model in eval() mode

I am working on an image classification model that classifies the emotion of the face in an image. I found that the model produces near-identical outputs for different training examples when set to eval() mode, but this does not happen in train() mode. Why does this happen?

Specifications:

  • macOS
  • torch version 2.0.0
  • torchvision version 0.15.1

Code:

import torch
from torch import nn
from torchvision import transforms
import os
from PIL import Image

BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 10
DEV = torch.device("cpu")

class EmotionModel(nn.Module):

    def __init__(self, p=0.25):
        super(EmotionModel, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Dropout(p),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Dropout(p),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Dropout(p),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Dropout(p),
        )

        self.dense_layers = nn.Sequential(
            # 48x48 input halved by four 2x2 max-pools -> 3x3 feature maps with 512 channels = 4608 features
            nn.Linear(4608, 512),
            nn.ReLU(),
            nn.Dropout(p),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p),

            nn.Linear(256, 7)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        return self.dense_layers(x.flatten(1))

def load_images(path):
    T = transforms.Compose([
        transforms.Resize((48, 48)),
        transforms.ToTensor()
        ])
    # Load every image in the folder (single-channel/grayscale to match in_channels=1),
    # apply the transform and stack them into one batch tensor
    images = [Image.open(os.path.join(path, img_name)) for img_name in os.listdir(path)]
    images = [T(img).unsqueeze(0) for img in images]
    return torch.concat(images, 0)
    
model = EmotionModel().to(DEV)
model.eval()
images = load_images('test-images').to(DEV)
output = model(images)
print(output)

Output with model.eval():

tensor([[ 0.0019, -0.0317,  0.0251, -0.0427, -0.0861,  0.0278,  0.0064],
        [ 0.0018, -0.0314,  0.0251, -0.0426, -0.0862,  0.0280,  0.0063],
        [ 0.0018, -0.0314,  0.0251, -0.0425, -0.0862,  0.0280,  0.0064],
        [ 0.0017, -0.0314,  0.0251, -0.0426, -0.0862,  0.0280,  0.0064],
        [ 0.0017, -0.0313,  0.0250, -0.0425, -0.0862,  0.0280,  0.0063]],
       grad_fn=<AddmmBackward0>)

Output without model.eval():

tensor([[-0.0598,  0.0206, -0.0462,  0.0168,  0.0156, -0.0317, -0.0510],
        [-0.0608,  0.0100, -0.0478,  0.0172,  0.0198, -0.0265, -0.0467],
        [-0.0532,  0.0180, -0.0511,  0.0146,  0.0189, -0.0234, -0.0428],
        [-0.0425,  0.0023, -0.0447,  0.0217,  0.0201, -0.0253, -0.0598],
        [-0.0532,  0.0014, -0.0529,  0.0256,  0.0113, -0.0312, -0.0498]],
       grad_fn=<AddmmBackward0>)

The difference in output between eval() and train() modes is due to the dropout layers, which are active only during training to prevent overfitting. In eval() mode dropout is disabled, so the outputs are deterministic; in train() mode the active dropout layers introduce randomness into the outputs. Always call model.eval() for evaluation and predictions to get consistent and accurate results.
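As a small illustration (a minimal sketch, not the original code): a dropout layer zeroes random elements and rescales the rest in train() mode, and is a no-op in eval() mode:

import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()       # training mode: random elements are zeroed, the rest scaled by 1/(1-p)
print(drop(x))     # e.g. tensor([[2., 0., 2., 0., 2., 2., 0., 2.]]) -- pattern differs every call

drop.eval()        # eval mode: dropout is disabled, the input passes through unchanged
print(drop(x))     # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])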

I’ve seen this effect sometimes when the model was “saturating” the outputs (and sometimes only returned the last bias term, because the penultimate activation was zeroed out by a ReLU).
In those runs it helped to use a smaller model, but I would also guess that e.g. a different parameter initialization could help (I didn’t experiment with it). A sketch of what that could look like is below.
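If you want to try a different initialization, a minimal sketch could look like this (init_weights is just an illustrative helper name; Kaiming/He init is a common choice for ReLU networks):

from torch import nn

def init_weights(module):
    # Kaiming/He init for conv and linear weights, zero biases
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = EmotionModel()
model.apply(init_weights)  # apply() walks every submodule recursively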

@AbdulsalamBande’s answer about the dropout behavior in train and eval mode is correct, but it doesn’t explain why the eval outputs are approximately static for different inputs.

That does seem to be the case: when I try to evaluate without training, the outputs are different, which suggests that during training the weights are converging to values that make the activations zero or negative. I tried a few different weight init methods and made the model a bit smaller, and that does seem to help. Thanks.
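One way to confirm this is to look at the activation feeding the final Linear layer with a forward hook (a rough diagnostic sketch based on the model definition above, where dense_layers[4] is the ReLU before nn.Linear(256, 7)):

acts = {}

def save_activation(module, inp, out):
    # store the output of the hooked layer for inspection
    acts['penultimate'] = out.detach()

hook = model.dense_layers[4].register_forward_hook(save_activation)
model.eval()
with torch.no_grad():
    _ = model(images)
hook.remove()

a = acts['penultimate']
# fraction of positive activations; close to 0 means only the final bias reaches the output
print(a.shape, (a > 0).float().mean().item())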
