Outputs are different depending on the input size

Thank you for reading my post.

My question is: how can I keep the small numerical errors that result from varying input sizes from accumulating through the layers?

This question may not be self-explanatory, so I will explain it below with a figure illustrating the problem.

Background:
Because of the limits of numerical accuracy, the input size (shape) influences the outputs even though the expected outputs are mathematically identical. This is explained in the doc Numerical accuracy — PyTorch 2.4 documentation, which states:

Then (A@B)[0] (the first element of the batched result) is not guaranteed to be bitwise identical to A[0]@B[0] (the matrix product of the first elements of the input batches) even though mathematically it’s an identical computation.
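
For illustration, here is a minimal sketch of that documented behaviour (the shapes are arbitrary; whether and by how much the two results differ depends on hardware, dtype, and the matmul backend):

import torch

# Arbitrary shapes for illustration; the size of the discrepancy depends on
# hardware, dtype, and the matmul backend in use.
A = torch.randn(8, 64, 32)
B = torch.randn(8, 32, 16)

batched = (A @ B)[0]   # first element of the batched result
single = A[0] @ B[0]   # matrix product of the first batch elements

print(torch.equal(batched, single))     # may be False: not bitwise identical
print((batched - single).abs().max())   # tiny, but possibly nonzero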

I suppose that such errors are usually so small that they can essentially be ignored. However, I noticed that they can grow as the input goes through multiple layers and become large enough to affect the model’s performance.

Experiment
I instantiated a simple model with 20 feed-forward layers with (or without) residual connections. I set up a random input of size (batchsize, 1024, 768). The output is expected to have the same size as the input.
Suppose I am now interested only in the last row of the input ([-1, -1, :]). I let the model compute it in two ways (a short single-layer sketch follows the list):

  1. Pass the entire input, and then slice out the last row from the output.
  2. Pass only the last row of the input
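
For example, with a single linear layer (not part of my experiment; arbitrary shapes), the two strategies can already produce slightly different values:

import torch

# Tiny self-contained sketch of the two strategies with one linear layer;
# the full experiment below repeats this across 20 feed-forward layers.
layer = torch.nn.Linear(768, 768).eval()
x = torch.randn(16, 1024, 768)

with torch.no_grad():
    full = layer(x)[-1, -1, :]          # strategy 1: full input, slice afterwards
    row = layer(x[-1:, -1:, :])[0, 0]   # strategy 2: pass only the last row
print((full - row).abs().max())         # typically small but often nonzero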

For the experiment I stored the intermediate outputs of each layer and computed the difference between the two strategies. Mathematically the results should be identical; however, they are slightly different, as explained above. As you can see in the figure, the small numerical errors grow roughly linearly with the layer index.


Question
How can I mitigate such accumulating errors? Are there any good ways to prevent this?
I am actually facing this problem in my model with 12 decoder-only transformer layers, which has many feed-forward layers with residual connections. The phenomenon is so prominent in my case that I cannot implement a KV cache (where I pass only the last row of the input).
I would very much appreciate any insights into this.

=== Code ===
The experiment and the figure can be reproduced with the code below:

import torch
import matplotlib.pyplot as plt

# ********************
# * Model Definition *
# ********************
class TestFeedForward(torch.nn.Module):
    """Simple residual layer (Linear -> GeLU -> Linear)"""
    def __init__(self, dim1, dim2):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(dim1)
        self.norm2 = torch.nn.LayerNorm(dim1)
        self.linear1 = torch.nn.Linear(dim1,dim2)
        self.gelu = torch.nn.GELU()
        self.linear2 = torch.nn.Linear(dim2,dim1)

    def forward(self, x, residual:bool=True):
        out = self.linear1(self.norm1(x))
        out = self.gelu(out)
        out = self.norm2(self.linear2(out))
        if residual:
            out = x + out
        return out

class TestMod(torch.nn.Module):
    """A test module with multiple feedforward layers with residual connections.
    """
    def __init__(self, n_layers:int, embed_dim:int):
        super().__init__()
        self.layers = torch.nn.ModuleList([TestFeedForward(embed_dim,embed_dim) for _ in range(n_layers)])
        # Parameter initialization
        for n, p in self.named_parameters():
            if "norm" in n and "weight" in n:
                torch.nn.init.constant_(p, 1)
            elif "bias" in n:
                torch.nn.init.constant_(p, 0)
            else:
                torch.nn.init.normal_(p, mean=0, std=0.2)
        # Initialize intermediate results
        self.register_buffer("checkpoints", None)
    def forward(self, x, residual:bool=True):
        out = x
        for layer in self.layers:
            # Forward
            out = layer(out, residual=residual)
            # Save the intermediate results as 'self.checkpoints'
            if self.checkpoints is None:
                self.checkpoints = out.unsqueeze(0)
            else:
                self.checkpoints = torch.cat([self.checkpoints, out.unsqueeze(0)])
        return out

# **********************
# * Conduct Experiment *
# **********************
# Instantiate a model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TestMod(n_layers=20, embed_dim=768).eval().to(device)
# Compare between different batchsizes
batchsize_list = [1, 8, 16]
fig, axes = plt.subplots(1,2, figsize=(10, 5))
for i, residual in enumerate([True, False]):
    for batchsize in batchsize_list:
        # Try forward computation with an entire batch
        b1 = torch.randn(size=(batchsize, 1024, 768),device=device)
        _ = model(b1, residual = residual)
        cp1 = model.checkpoints # <- this is the set of intermediate computation results
        model.checkpoints = None
        # Try forward computation only with the last rows (simulating KV-cache)
        b2 = b1[-1:, -1:, :].clone()
        _ = model(b2, residual = residual)
        cp2 = model.checkpoints
        model.checkpoints = None
        # Compute the difference and plot
        diff = (cp1[:, -1, -1, :] - cp2[:, -1, -1, :]).abs().mean(-1)
        axes[i].plot(diff.tolist(), label=f"Batchsize = {batchsize}")

axes[0].set_title("Residual connection on")
axes[1].set_title("Residual connection off")
axes[0].legend()
axes[1].legend()
fig.suptitle("Impact of variable input sizes on numerical errors: (1,1,768) vs (batchsize,1024,768)")
fig.supxlabel("Number of layers")
fig.supylabel("Numerical errors of intermediate results")
plt.show()

The absolute error increase could be expected if the magnitude of the activations also grows. Comparing two results of the same dtype against each other also doesn't give you a lot of information, since neither result is more accurate than the other. Compare your intermediates against a wider dtype (e.g. float64) to check the actual errors.
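
For example, something along these lines could serve as the wider-dtype check (it reuses model, TestMod, and device from the code above; the float64 deepcopy reference and the variable names are just one possible way to set it up, not part of the original experiment):

import copy

# Float64 copy of the same weights, used as a wider-dtype reference.
# (Assumes model.checkpoints is None here, as it is after each run in the loop.)
ref_model = copy.deepcopy(model).double().eval()

with torch.no_grad():
    x = torch.randn(8, 1024, 768, device=device)

    _ = model(x)                        # float32, full batch
    cp32_full = model.checkpoints
    model.checkpoints = None

    _ = model(x[-1:, -1:, :].clone())   # float32, last row only
    cp32_row = model.checkpoints
    model.checkpoints = None

    _ = ref_model(x.double())           # float64 reference (full batch)
    cp64 = ref_model.checkpoints
    ref_model.checkpoints = None

# Per-layer error of each float32 strategy against the float64 reference
err_full = (cp32_full[:, -1, -1, :].double() - cp64[:, -1, -1, :]).abs().mean(-1)
err_row = (cp32_row[:, -1, -1, :].double() - cp64[:, -1, -1, :]).abs().mean(-1)
print(err_full)
print(err_row)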