Systematic Study of Numerical Instability Caused by Extreme Weight Initialization in PyTorch

Problem Description

I’m investigating how weight initialization affects training stability in neural networks. Two architecturally similar models show drastically different behaviors:

import torch
import torch.nn as nn

# Stable version
class SimpleModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=2):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)  # Default init (uniform within ±1/sqrt(fan_in))
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        # Keep the first layer in full precision even under mixed-precision (AMP) training
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.layer1(inputs)
        x = self.layer2(x)
        return x

# Unstable version
class UnstableModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=2):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        
        # Extreme initialization
        nn.init.uniform_(self.layer1.weight, -100, 100)  # Range >> typical
        nn.init.zeros_(self.layer1.bias)
        nn.init.uniform_(self.layer2.weight, -100, 100)
        nn.init.zeros_(self.layer2.bias)

    def forward(self, inputs):
        x = self.layer1(inputs)  # No stability measures
        x = self.layer2(x)
        return x

When comparing two architecturally identical models with different initializations, we observe outputs differing by 4 orders of magnitude:

Input = [0.1, 0.2, ..., 1.0]  # 10D features

# Output comparison
Stable_model = [-0.13, -0.05]      # Magnitude: 10^0
Extreme_init_model = [28528, 55630] # Magnitude: 10^4
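
For reference, a sketch of how a comparison like this can be produced (the exact numbers will differ, since the extreme initialization is random):

import torch

torch.manual_seed(0)

features = torch.linspace(0.1, 1.0, 10).unsqueeze(0)  # shape (1, 10): [0.1, 0.2, ..., 1.0]

stable = SimpleModel(input_dim=10)
unstable = UnstableModel(input_dim=10)

with torch.no_grad():
    print("stable:      ", stable(features).squeeze().tolist())
    print("extreme init:", unstable(features).squeeze().tolist())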

Questions:

  1. Why does extreme weight initialization (e.g., ∈[-100,100]) cause exponential output growth?
  2. How to identify the layer where instability first occurs?
  3. What does this reveal about model robustness?

Hi Xiaoyao!

First comment:

If you pass inputs of order one through a Linear whose weights are of order 10^2, you would expect to get outputs of order 10^2. Stacking two such order-10^2 Linears scales the signal by roughly 10^2 twice, which naturally produces outputs of order 10^4.
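
To see where the scale first blows up (your second question), something like the following sketch registers a forward hook on each Linear of your UnstableModel and prints the magnitude of its output:

import torch
import torch.nn as nn

torch.manual_seed(0)

model = UnstableModel(input_dim=10)  # definition from the question above
x = torch.linspace(0.1, 1.0, 10).unsqueeze(0)

def log_magnitude(name):
    # Forward hook that prints the largest absolute value the layer produced
    def hook(module, inputs, output):
        print(f"{name}: max |output| = {output.abs().max().item():.3g}")
    return hook

for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(log_magnitude(name))

with torch.no_grad():
    model(x)

# Expect layer1 to already be of order 10^2 and layer2 of order 10^4.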

Second comment:

You asked about training stability, but you haven’t told us anything about how (or whether)
you trained your model.

However, gradient-descent optimization routinely becomes unstable when the steps it takes are too large. The size of a step is the size of the gradient times the learning rate. If the resulting step is too large, you jump to a place where the weights are worse and the gradient is larger; then you take an even bigger jump to where the weights are worse still and the gradient is larger still, and training systematically diverges.

When you initialize the weights to large values, the gradients will start out large and you risk this kind of divergence / instability. You can compensate for this by starting with a smaller learning rate, say one smaller by a factor of roughly 10^4 (that is, your usual learning rate multiplied by 10^-4).
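
As a rough way to see how large the starting gradients are (and hence how much smaller the learning rate needs to be), here is a sketch using made-up random data and a standard classification loss, both of which are assumptions on my part:

import torch
import torch.nn as nn

torch.manual_seed(0)

inputs = torch.randn(64, 10)          # illustrative random batch
targets = torch.randint(0, 2, (64,))  # illustrative labels for the 2-class output

loss_fn = nn.CrossEntropyLoss()

for name, model in [("default init", SimpleModel(10)), ("extreme init", UnstableModel(10))]:
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    grad_norm = torch.cat([p.grad.flatten() for p in model.parameters()]).norm()
    print(f"{name}: loss = {loss.item():.3g}, gradient norm = {grad_norm.item():.3g}")

# The step an optimizer takes is roughly learning_rate * gradient, so a gradient
# that is orders of magnitude larger calls for a learning rate that is orders of
# magnitude smaller.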

(Note, if you systematically increase the learning rate with which you train the stable
version of your model, you will be able to make its training similarly unstable.)

This doesn’t really say anything about the robustness of the model itself. It’s the mismatch between the learning rate and the perversely large initial values of the model’s weights that is causing the unstable training.

(In your case, you should be able to start training your “unstable” model with a small
learning rate and after the weights have evolved to more reasonable values, continue
training with a larger, more typical learning rate.)
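
A minimal sketch of that two-phase schedule, where the learning-rate values, the switch point, and the random data are all illustrative assumptions:

import torch
import torch.nn as nn

torch.manual_seed(0)

model = UnstableModel(input_dim=10)   # from the question above
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(64, 10)          # illustrative random batch
targets = torch.randint(0, 2, (64,))

# Phase 1: a learning rate roughly 10^4 smaller than a typical one, to survive the huge weights
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7)

for step in range(2000):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Phase 2: after a warm-up period (or once you see the weights reach a
    # reasonable scale), switch back to a more typical learning rate
    if step == 1000:
        for group in optimizer.param_groups:
            group["lr"] = 1e-3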

Best.

K. Frank