Classifier model sometimes outputs 0s, which breaks learning

I am trying to replicate the tutorial model found at What is torch.nn really? — PyTorch Tutorials 2.1.0+cu121 documentation, but in a functional style rather than OO style. When I train the model, it works sometimes, but sometimes the model gets into a state where it will return all 0s for the output for some training and validation cases. When it gets into this state, it fails to learn properly.

Sometimes I’ll get lucky with the random starting weights and training data shuffle and it will train flawlessly without any problems.

My guess is that there is some bug in my model that is causing values to reach extreme numbers (0, inf) or something, but I don’t know what it might be. I believe I have replicated the tutorial model, other than syntax/style, but maybe I’m missing something.

The gist of my model is the following (reformatted a bit to facilitate troubleshooting):

import math
import torch
from torch.nn.functional import cross_entropy, conv2d, relu, adaptive_avg_pool2d
from functools import partial
from torch.nn import Parameter
from torch.nn.init import uniform_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
empty_tensor = partial(torch.empty, device=device)

def gen_conv2d_parameters(chan_in, chan_out, kernel_size, groups):
    weight_tensor = empty_tensor(chan_out, chan_in // groups, kernel_size, kernel_size)
    bias_tensor = empty_tensor(chan_out)
    bound = math.sqrt(groups / (chan_in * kernel_size * kernel_size))
    weight_parameter = Parameter(uniform_(weight_tensor, -bound, bound))
    bias_parameter = Parameter(uniform_(bias_tensor, -bound, bound))
    # ... store parameters ...
    return weight_parameter, bias_parameter

weight1, bias1 = gen_conv2d_parameters(1, 16, 3, 1)
weight2, bias2 = gen_conv2d_parameters(16, 16, 3, 1)
weight3, bias3 = gen_conv2d_parameters(16, 10, 3, 1)

def model(x):
    x = x.view(-1, 1, 28, 28)
    x = relu(conv2d(x, weight1, bias1, stride=2, padding=1))
    x = relu(conv2d(x, weight2, bias2, stride=2, padding=1))
    x = relu(conv2d(x, weight3, bias3, stride=2, padding=1))
    x = adaptive_avg_pool2d(x, 1)
    return x.view(-1, 10)

def train_one_batch(input_batch, expected_batch):
    output_batch = model(input_batch)
    loss = cross_entropy(output_batch, expected_batch)
    # ... backward pass and optimizer step ...
How are you setting the learning rate of your optimizer and have you tried decreasing it?

Learning rate is 0.5. This was selected because the tutorial told me to, and I’m assuming the tutorial trains correctly. What is the correct way to choose a learning rate?

momentum = 0.9 (also because tutorial told me to).

Changing the learning rate to 0.1 makes the problem go away (and it learns much better in general). Changing the learning rate to 0.9 causes it to learn incredibly poorly and the “all 0s” issue comes up frequently.

While interesting, this is not really a solution to my problem because this means that either:

  1. the main PyTorch tutorial does not train correctly
  2. my code doesn’t correctly implement the same algorithm as the tutorial

I find (1) to be quite unlikely, which leaves me with (2) likely unsolved by this adjustment of the learning rate.

Have you verified (1)? I wouldn’t be too surprised if the toy example in the tutorial has initialization-dependent convergence. It could also be that you are running on a machine (e.g., a GPU using TF32 for convolutions) that has somewhat different numerical behavior than what the tutorial was tested with (e.g., true single-precision floating point).
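For reference, the TF32 behavior mentioned above can be inspected and switched off via backend flags (available in recent PyTorch versions), which is one way to rule it out as a source of numerical differences:

```python
import torch

# On Ampere-or-newer GPUs, PyTorch may run convolutions and matmuls in
# TF32, which has a shorter mantissa than true FP32.
print(torch.backends.cudnn.allow_tf32)        # controls conv numerics
print(torch.backends.cuda.matmul.allow_tf32)  # controls matmul numerics

# Force full FP32 everywhere to match CPU-like numerical behavior:
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
```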

Re: (2) In general, there is no silver bullet solution to selecting a learning rate, and it is natural to tune this depending on the problem setting. See also:

The tutorial does have a step for moving to a GPU (which I followed), but I just tried switching back to CPU and I still sometimes get the all 0s output (just like when on GPU).

I have not. Given that I don’t see anyone else complaining about this anywhere on the forums, it feels far-fetched that the main tutorial for PyTorch wouldn’t work properly out of the box. Am I being too optimistic here? Would it be appropriate in this case to assume that no one has actually done the tutorial and critically inspected the output? I suppose I could replicate the tutorial exactly and look at its results myself, though I would rather not go through all of that effort unnecessarily.

Tangentially to the questions regarding the tutorial, what is happening internally that results in this behavior? It feels like all 0s is strictly incorrect and the model should be somehow protected against getting to that point. Is there some way to ensure that the model can get close to 0 but never 0 so that it can recover from “too small numbers”?

BTW, note that the learning rate in the tutorial is updated to 0.1 here:

In general, there are methods (e.g., different optimizers such as Adam vs. SGD or changes to the model architecture such as adding normalization layers) in addition to lowering the learning rate to prevent the model from diverging.
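To illustrate the optimizer swap (the tensors here are hypothetical, not from the tutorial): Adam maintains per-parameter adaptive step sizes, which often makes it less sensitive to a poor global learning-rate choice than plain SGD.

```python
import torch

# a toy parameter and a toy loss with a known minimum at w = 0
w = torch.nn.Parameter(torch.randn(16, 1, 3, 3))
optimizer = torch.optim.Adam([w], lr=1e-3)  # a common default-ish Adam lr

loss = (w ** 2).sum()
loss.backward()
optimizer.step()        # adaptive, per-parameter update toward the minimum
optimizer.zero_grad()
```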
Some related concepts:
Dying ReLU: machine learning - What is the "dying ReLU" problem in neural networks? - Data Science Stack Exchange
Vanishing gradient problem - Wikipedia
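The dying-ReLU failure mode can be demonstrated in a few lines: once a unit's pre-activation is negative for every input (here forced with an artificially large negative bias), ReLU outputs zero and, crucially, passes back a zero gradient, so gradient descent can never revive the unit.

```python
import torch
from torch.nn.functional import relu

x = torch.randn(8, 4)
w = torch.randn(4, 1, requires_grad=True)
b = torch.full((1,), -100.0, requires_grad=True)  # bias pushed far negative

out = relu(x @ w + b)  # pre-activations are all negative, so out is all zeros
out.sum().backward()

# ReLU's gradient is zero on the negative side, so w.grad is exactly zero:
# the unit is "dead" and no SGD step can change that.
print(w.grad.abs().sum())
```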


Ahah! I think that is the bit I missed! Somehow I didn’t notice that one line change. :confounded: Thanks, your instincts appear to have been spot on. :tada:

For future readers: Switching to leaky_relu for my convolutional layers allowed me to bump the learning rate up to 0.9 without running into the 0s problem I previously had with learning rates of either 0.5 or 0.9 (and presumably anything in between).
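Why this helps, in miniature: unlike relu, leaky_relu keeps a small nonzero slope (0.01 by default) on the negative side, so a unit whose pre-activations have all gone negative still receives a gradient and can recover.

```python
import torch
from torch.nn.functional import leaky_relu

# inputs that a plain ReLU would zero out, killing the gradient entirely
x = torch.full((1, 3), -5.0, requires_grad=True)

out = leaky_relu(x, negative_slope=0.01)  # negative inputs scaled, not zeroed
out.sum().backward()

# each gradient element is 0.01 (the negative slope), not 0,
# so the parameters feeding these units can still be updated
print(x.grad)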