Batch size changes (Linear) layer output values

I’ve noticed that one of my models changes its output/predictions if I change the batch size (number of sequences). After some tests, I’ve found that nn.Linear layers give slightly different results for the same input values if the batch size is changed.

Running the following code:

import torch
import torch.nn as nn

import numpy as np
import random as rnd

seed = 101

rnd.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

f = nn.Linear(1024, 64).cuda()
x = torch.ones(16, 300, 1024).float().cuda()

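# Compare slicing the output of the full batch vs. forwarding only the first n+1 sequences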
for n in range(x.shape[0]):
    y = f(x)[:n+1]
    z = f(x[:n+1])

    print(torch.equal(y, z), torch.abs(y - z).mean().item())

I get:

False 2.0721927285194397e-08
False 1.862645149230957e-08
False 1.862645149230957e-08
False 1.862645149230957e-08
False 1.862645149230957e-08
False 1.862645149230957e-08
False 1.5133991837501526e-08
False 1.5133991837501526e-08
False 1.5133991837501526e-08
False 1.5133991837501526e-08
False 1.5133991837501526e-08
False 1.5133991837501526e-08
True 0.0
True 0.0
True 0.0
True 0.0

The same is also true for nn.Conv1d layers, though seemingly not for nn.Conv2d layers. Setting torch.backends.cudnn.enabled = False seems to fix the issue for nn.Conv1d, but not for nn.Linear. Finally, this effect only appears on the GPU; it does not occur when the same calculations are run on the CPU.
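
For reference, my nn.Conv1d check followed the same pattern as the snippet above (a sketch with arbitrary layer and input sizes, not my exact code):

g = nn.Conv1d(64, 32, kernel_size=3, padding=1).cuda()
xc = torch.ones(16, 64, 300).cuda()  # (batch, channels, length)

for n in range(xc.shape[0]):
    y = g(xc)[:n+1]
    z = g(xc[:n+1])

    print(torch.equal(y, z), torch.abs(y - z).mean().item())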

Is this expected behavior (due to floating-point precision issues) or a bug?
I’m running PyTorch 1.9.0 (py3.8_cuda11.1_cudnn8_0), if that makes a difference.

This would be expected, as different batch sizes will likely dispatch to different algorithms, which can have different numerical behavior (e.g., due to differences in how reductions are ordered).

I observe similar behavior with float32 on sm_86. However, changing the workspace size via CUBLAS_WORKSPACE_CONFIG=:16:8 seems to reduce the number of batch sizes that produce slightly different results, and using float64 yields equal comparisons throughout. In general, however, there are no guarantees of identical behavior across different batch sizes: deterministic mode is only intended to guarantee reproducible results for the same batch size.
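
If you want to try the workspace setting from inside the script rather than from the shell, a minimal sketch would be (the variable has to be picked up before the first cuBLAS call, so setting it before importing torch is the safest option):

import os

# Assumption: must be set before the first cuBLAS call, otherwise it's ignored
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

import torch
torch.use_deterministic_algorithms(True)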


Thank you for the quick answer!

Setting CUBLAS_WORKSPACE_CONFIG to :16:8 (instead of my previous :4096:8) indeed reduced the issue quite a lot, at least for nn.Linear. Going from float32 to float64 reduced the numerical differences by several orders of magnitude, but did not eliminate them, at least on my hardware (RTX 3060).

If I understand correctly, I would always need to enforce the same batch size as used during training, even when predicting only a single sample (i.e. pad the batch with dummy samples). Otherwise I risk that:
a) I get worse performance, because the model performs different calculations than it did during training.
b) I get different results for the same sample, depending on how many other samples I predict in the same batch.

No, usually you wouldn’t expect a) or b): the limited numerical precision is expected, and neither value is the “true” value, as both depend on which values are representable in float32.
You can read more about it in this Wikipedia article.
If your model is sensitive to the expected numerical noise, you might need to use a wider dtype, such as float64.
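
E.g., for the snippet above, casting the layer and input to float64 is a one-liner each (a sketch, at the cost of speed and memory on most consumer GPUs):

f64 = f.double()   # cast parameters to float64
x64 = x.double()   # cast inputs to float64

for n in range(x64.shape[0]):
    y = f64(x64)[:n+1]
    z = f64(x64[:n+1])

    print(torch.equal(y, z), torch.abs(y - z).mean().item())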