I’m investigating how weight initialization affects training stability in neural networks. Two architecturally similar models show drastically different behaviors:
import torch
import torch.nn as nn

# Stable version
class SimpleModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=2):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)  # Default init (uniform in ±1/sqrt(fan_in))
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        # Run the forward pass in full precision, with autocast explicitly disabled
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.layer1(inputs)
            x = self.layer2(x)
        return x
# Unstable version
class UnstableModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=2):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        # Extreme initialization
        nn.init.uniform_(self.layer1.weight, -100, 100)  # Range >> typical
        nn.init.zeros_(self.layer1.bias)
        nn.init.uniform_(self.layer2.weight, -100, 100)
        nn.init.zeros_(self.layer2.bias)

    def forward(self, inputs):
        # No stability measures
        x = self.layer1(inputs)
        x = self.layer2(x)
        return x
When comparing these two architecturally identical models with different initializations, we observe outputs differing by four orders of magnitude.
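A minimal sketch of that comparison, assuming an arbitrary input_dim of 16 and a standard-normal batch (both chosen purely for illustration):

torch.manual_seed(0)
inputs = torch.randn(32, 16)   # batch of order-one inputs; the shape is illustrative

stable = SimpleModel(input_dim=16)
unstable = UnstableModel(input_dim=16)

with torch.no_grad():
    print(stable(inputs).abs().mean())    # small, order-one-ish outputs
    print(unstable(inputs).abs().mean())  # outputs several orders of magnitude larger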
First comment:
If you pass inputs of order one through a Linear whose weights are of order 10^2, you would expect to get outputs of order 10^2. Passing those inputs through two order-10^2 Linears would naturally produce outputs of order 10^4.
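A back-of-the-envelope check of that scaling argument (the 128-dimensional shapes are arbitrary, and the exact magnitudes also pick up factors of roughly sqrt(fan_in), which the order-of-magnitude argument ignores):

import torch

x = torch.randn(128)                              # order-one input
W1 = torch.empty(128, 128).uniform_(-100, 100)    # order-10^2 weights
W2 = torch.empty(2, 128).uniform_(-100, 100)

h = W1 @ x   # entries of order 10^2 (times a sqrt(fan_in)-ish factor)
y = W2 @ h   # entries of order 10^4 (times another such factor)
print(h.abs().mean(), y.abs().mean())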
Second comment:
You asked about training stability, but you haven’t told us anything about how (or whether)
you trained your model.
However, gradient-descent optimization routinely becomes unstable when the steps taken
are too large. The size of a step is the size of the gradient times the learning rate. If the
resulting step is too large, you jump to a place where the weights are worse and the
gradient is larger. Then you take an even bigger jump to a place where the weights are worse still and the gradient larger still, and training systematically diverges.
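This feedback loop is easy to see in one dimension: a minimal sketch minimizing f(w) = w^2, where the gradient is 2*w and any learning rate above 1.0 overshoots by more than it corrects:

def gradient_descent(w, lr, steps=10):
    for _ in range(steps):
        grad = 2.0 * w       # gradient of f(w) = w^2
        w = w - lr * grad    # step size = gradient times learning rate
    return w

print(gradient_descent(w=1.0, lr=0.1))   # shrinks toward the minimum at 0
print(gradient_descent(w=1.0, lr=1.1))   # |1 - 2*lr| > 1, so |w| grows every step: divergence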
When you initialize the weights to large values, the gradients will start out large and you risk this kind of divergence / instability. You can compensate for this by starting with a smaller learning rate, say one scaled down by a factor of about 10^4 (i.e., multiplied by 10^-4).
(Note, if you systematically increase the learning rate with which you train the stable
version of your model, you will be able to make its training similarly unstable.)
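In code, that compensation is just a much smaller lr for the extreme-initialized model; the optimizer choice, the 1e-2 baseline, and input_dim below are illustrative assumptions, reusing the classes from the question:

stable = SimpleModel(input_dim=16)
unstable = UnstableModel(input_dim=16)

opt_stable = torch.optim.SGD(stable.parameters(), lr=1e-2)
# Scale the learning rate down by ~10^4 to offset the ~10^2-scale weights in each of the two layers
opt_unstable = torch.optim.SGD(unstable.parameters(), lr=1e-2 * 1e-4)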
This doesn’t really say anything about the robustness of the model itself. It’s the mismatch
between the learning rate and the perversely-large initial values of the model’s weights that
is causing the unstable training.
(In your case, you should be able to start training your “unstable” model with a small learning rate and then, after the weights have evolved to more reasonable values, continue training with a larger, more typical learning rate.)
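A rough sketch of that two-phase schedule, reusing the question's classes; the loss, the random stand-in data, the warm-up length, and both learning rates are placeholders rather than values from the discussion:

model = UnstableModel(input_dim=16)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)  # tiny lr while the weights are extreme

for step in range(2000):                      # placeholder warm-up length
    inputs = torch.randn(32, 16)              # stand-in batch; use your real data here
    targets = torch.randint(0, 2, (32,))
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Once the weights have moved to more reasonable values, switch to a typical learning rate
for group in optimizer.param_groups:
    group["lr"] = 1e-2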