Model in .eval() mode predicts the same (or nearly the same) outputs for all inputs

Hi!
I’m training a modified DETR transformer model on a custom dataset.
In .train() mode the model produces normal predictions (all different), but when I run it in .eval() mode for evaluation, the outputs of the model are all the same (or almost the same).
What could be the problem?
The LR is 1e-5 and the output layer is linear.
I will appreciate any advice!

@ptrblck I have seen many posts where you helped people with a similar problem - could you have a look at my case as well? :slight_smile:

During the evaluation (and after calling model.eval()) some layers change their behavior, e.g. batchnorm layers are using the running stats to normalize the input activations and dropout layers are disabled.
If the accuracy of the model decreases significantly during the eval phase, noisy batchnorm stats could be the cause and you could either try to increase the batch size during training or change the momentum of the batchnorm layers to smooth the updates.
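As a small sketch (assuming the model uses the standard nn.BatchNorm layers; the helper name is just for illustration), the momentum could be lowered e.g. like this:

    import torch.nn as nn

    def smooth_batchnorm_stats(model: nn.Module, momentum: float = 0.01) -> None:
        # Lower the momentum of every batchnorm layer so that the running
        # stats are updated more smoothly (the PyTorch default is 0.1).
        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.momentum = momentum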

Hi @ptrblck
Thank you for your answer. I’m actually using nn.LayerNorm layers instead of nn.BatchNorm, as in the original DETR code. So I guess the batch size doesn’t matter for me, and I also can’t change the momentum of those layers. Am I correct?
My dataset has about 1000 samples, the batch size is currently set to 4, and I’m training on 3 GPUs at a time.
Do you have any other suggestions?

Ah OK, that’s interesting as these layers shouldn’t have any running stats registered as buffers.
In any case, you should check which layers could change their behavior between training and validation runs. In case you are using custom layers, check if they are using their internal self.training attribute to change the forward pass.
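As a quick sketch, you could list the standard layers whose forward pass depends on self.training (custom layers would still need a manual check):

    import torch.nn as nn

    def modules_affected_by_eval(model: nn.Module):
        # Dropout and all batchnorm variants behave differently between
        # train() and eval(); anything returned here is a candidate.
        suspects = (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                    nn.modules.batchnorm._BatchNorm)
        return [(name, type(module).__name__)
                for name, module in model.named_modules()
                if isinstance(module, suspects)]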

Custom TransformerEncoder and TransformerDecoder layers are used, but I can’t find self.training being used anywhere in them. Let me post their structure here:

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

So I have no idea what could cause the differences between the training and evaluation parts. Could it be because of the dropout usage? The dropout value is currently set to 0.1.

Yes, dropout will be disabled during evaluation as previously mentioned and would thus change the output. I don’t know if it is responsible for your issue, but if no other layers change their behavior, I would guess so. Could you disable dropout during training as well and check whether the training output then also becomes static?
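E.g. as a quick sketch (assuming all dropout in the model comes from plain nn.Dropout modules), you could zero out the drop probability without changing the model definition:

    import torch.nn as nn

    def disable_dropout(model: nn.Module) -> None:
        # With p=0.0 every nn.Dropout becomes a no-op while the model
        # itself stays in train() mode.
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0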

I did some experiments as you suggested!

  1. TRAINING mode with default dropout = 0.1:
    Transformer output:
    tensor([[[-1.2072, -0.0247, 1.4827, …, 1.9502, -1.0014, -0.0777],
    [-1.0472, 1.0473, -0.1111, …, 1.9880, -0.7706, -0.9534],
    [-0.9117, 0.6312, 1.5603, …, 1.7538, -1.7898, -0.0764],
    …,
    [-1.5139, 0.8651, 1.3959, …, 1.4869, -0.7368, -0.9495],
    [-1.3713, 0.1539, 1.1079, …, 1.7069, 1.0330, -0.5884],
    [-0.9833, 0.9340, 1.4092, …, 2.1241, -0.9530, -0.0743]]],
    device='cuda:0', grad_fn=<...>)
    And final output of the model:
    {'pred_logits': tensor([[-0.0081, 0.2497, -0.3607, …, -0.0395, -0.2595, -0.0823],
    [ 0.0760, 0.3130, -0.2936, …, 0.1302, -0.7095, -0.1434],
    [ 0.0592, 0.2099, -0.3150, …, 0.1744, -0.8228, -0.6104],
    …,
    [ 0.0553, 0.0501, -0.2397, …, 0.2563, -0.5291, -0.3276],
    [-0.0884, -0.1168, -0.1549, …, 0.1795, -0.1136, 0.0661],
    [ 0.3371, 0.1840, -0.7856, …, 0.2078, -0.0932, -0.4215]],
    device='cuda:0', grad_fn=<...>)}

  2. TRAINING mode with dropout = 0.0 - I assume this means dropout is effectively off, right?
    tensor([[[-0.8847, 0.8008, 1.3410, …, 1.9398, -0.9303, -0.7386],
    [-0.9201, 0.8979, 1.4128, …, 1.9263, -0.8656, -0.8076],
    [-0.8593, 0.8811, 1.3582, …, 1.8995, -0.9350, -0.7555],
    …,
    [-0.9158, 0.9414, 1.3551, …, 1.8739, -0.8665, -0.9009],
    [-0.8818, 0.8018, 1.3489, …, 1.9480, -0.9242, -0.8278],
    [-0.8603, 0.7642, 1.3655, …, 2.0025, -0.9490, -0.8010]]],
    device='cuda:0', grad_fn=<...>)
    {'pred_logits': tensor([[-0.1241, 0.0806, -0.3436, …, 0.2659, -0.3817, -0.2606],
    [-0.0993, 0.0815, -0.3826, …, 0.2688, -0.3978, -0.2872],
    [-0.1544, 0.0542, -0.3570, …, 0.2584, -0.3570, -0.2809],
    …,
    [-0.0540, 0.0955, -0.3878, …, 0.2575, -0.3749, -0.3109],
    [-0.1354, 0.0464, -0.3761, …, 0.2357, -0.3562, -0.2966],
    [-0.1206, 0.0465, -0.3985, …, 0.2530, -0.3951, -0.2766]],
    device='cuda:0', grad_fn=<...>)}

  3. EVAL mode - with or without dropout doesn’t matter here, as dropout is disabled in eval:
    tensor([[[-0.8847, 0.8008, 1.3410, …, 1.9398, -0.9303, -0.7386],
    [-0.9201, 0.8979, 1.4128, …, 1.9263, -0.8656, -0.8076],
    [-0.8593, 0.8811, 1.3582, …, 1.8995, -0.9350, -0.7555],
    …,
    [-0.9158, 0.9414, 1.3551, …, 1.8739, -0.8665, -0.9009],
    [-0.8818, 0.8018, 1.3489, …, 1.9480, -0.9242, -0.8278],
    [-0.8603, 0.7642, 1.3655, …, 2.0025, -0.9490, -0.8010]]],
    device='cuda:0', grad_fn=<...>)
    {'pred_logits': tensor([[-0.1241, 0.0806, -0.3436, …, 0.2659, -0.3817, -0.2606],
    [-0.0993, 0.0815, -0.3826, …, 0.2688, -0.3978, -0.2872],
    [-0.1544, 0.0542, -0.3570, …, 0.2584, -0.3570, -0.2809],
    …,
    [-0.0540, 0.0955, -0.3878, …, 0.2575, -0.3749, -0.3109],
    [-0.1354, 0.0464, -0.3761, …, 0.2357, -0.3562, -0.2966],
    [-0.1206, 0.0465, -0.3985, …, 0.2530, -0.3951, -0.2766]],
    device='cuda:0', grad_fn=<...>)}

All of 1-3 were run on the same training sample.
So 2 and 3 produce identical outputs, and within them all rows are very similar to each other, column by column.
That means the problem with .eval() in my case comes from the Dropout layers… Could this happen because of my data? For example, should I scale the dataset?
Then the question is: how can I overcome this issue?

Hi @ptrblck, can you please share any further ideas? :slight_smile:

Hi,
I have the same problem as @KhanMar: I also get identical outputs in .eval() evaluation mode. My model also includes a transformer. I think this bug is in the transformer itself.

I have encountered the same issue as well. I trained a model with dropout p=0 in the transformer encoder block. Later, when I load the weights, I get correct results from train() with torch.no_grad(), and incorrect results that are uniform regardless of the input when I use eval().
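For reference, this is roughly the inference setup I mean (a sketch; model, the checkpoint path, and images are placeholders, and dropout is already at p=0, so train() and eval() should otherwise be equivalent):

    import torch

    model.load_state_dict(torch.load("detr_checkpoint.pth"))  # placeholder path
    model.train()            # stay on the training code path instead of eval()
    with torch.no_grad():    # no gradients are needed for inference
        outputs = model(images)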

I noticed that there were some big structural changes in the transformer code in 1.12. I will try again using 1.11 and report back.

I believe the bug is here:

        why_not_sparsity_fast_path = ''
        ...
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        ...

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

There is a long chain of conditionals that creates different behavior, and it is likely to reproduce differently under many circumstances. If you look at the code, why_not_sparsity_fast_path is always set during training, but not in eval.

In eval mode, the output is modified using src_key_padding_mask. If you are not using a src_key_padding_mask, you will not observe the behavior, and if you set it to None when using eval(), the model works as expected (but without the src key padding mask).
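As a rough sketch of that workaround (encoder and src are placeholders for your nn.TransformerEncoder and its input):

    encoder.eval()
    with torch.no_grad():
        # Dropping the mask avoids the nested-tensor fast path in 1.12,
        # but padded positions are then attended to like normal tokens.
        output = encoder(src, src_key_padding_mask=None)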

This big conditional for why_not_sparsity_fast_path is not used in version 1.11, so I expect that rolling back will fix the issue.

Edit: I can confirm that loading my weights using 1.11 fixes the issue for me.

Could you create an issue on GitHub with your explanation, please, so that the code owners could take a look at it?

I’ve opened an issue here: TransformerEncoder src_key_padding_mask does not work in eval() · Issue #86120 · pytorch/pytorch · GitHub

Thanks so much for opening a defect on this issue. We’re trying to recreate the problem, but we have not been able to create a test case that produces incorrect results in the presence of src_key_padding_mask. Thus, I wanted to follow up on Rui’s request for a repro for this issue at TransformerEncoder src_key_padding_mask does not work in eval() · Issue #86120 · pytorch/pytorch · GitHub

cc: @dleemiller

I added a code snippet that I think narrows down the issue.

I think that if the src_key_padding_mask is supplied as a boolean mask, the issue is resolved; however, using another dtype (long, float) appears to be problematic.
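A small sketch of what I mean (shapes and sizes are just examples):

    import torch
    import torch.nn as nn

    layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
    encoder = nn.TransformerEncoder(layer, num_layers=2).eval()

    src = torch.randn(2, 5, 16)                     # (batch, seq, d_model)
    pad_bool = torch.zeros(2, 5, dtype=torch.bool)  # True marks padded positions
    pad_float = pad_bool.float()                    # non-bool dtype was problematic for me

    with torch.no_grad():
        out_ok = encoder(src, src_key_padding_mask=pad_bool)    # behaves as expected
        out_bad = encoder(src, src_key_padding_mask=pad_float)  # triggered the issue on 1.12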

This is super helpful. Thanks so much for narrowing this down, and providing a repro.