CNN (sometimes) starts outputting only zeros and does not learn during training?

Although I am working on an RL problem, I am tagging this under vision because the issue seems to be with the convolutional network I am using. I have tried my code on a different problem with a non-convolutional architecture and it worked fine.

I’m implementing a Reinforcement Learning algorithm to learn Atari games, using the architecture from Human Level Control Through Deep Reinforcement Learning for the Q-network. The network takes the last four frames in grayscale, scaled to 84x84, so the input is a (4, 84, 84) tensor, and it outputs a 1D tensor with one scalar per action in the game (4 in the case of Breakout, which I tried). It has three convolutional layers followed by two linear layers, with a ReLU activation after each layer. The exact architecture can be seen in the code example below.
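
As a side note, here is a quick shape check (a minimal sketch, separate from my actual code) confirming that this convolutional stack maps a (4, 84, 84) input to the 3136 features expected by the first linear layer:

# sanity check (sketch): the conv stack should turn a (1, 4, 84, 84) input
# into 64 feature maps of size 7x7, i.e. 64 * 7 * 7 = 3136 values after flattening
import torch
from torch.nn import Sequential, ReLU, Conv2d, Flatten

conv_stack = Sequential(
    Conv2d(4, 32, kernel_size=8, stride=4), ReLU(),
    Conv2d(32, 64, kernel_size=4, stride=2), ReLU(),
    Conv2d(64, 64, kernel_size=3, stride=1), ReLU(),
    Flatten(start_dim=1),
)
print(conv_stack(torch.zeros(1, 4, 84, 84)).shape)  # torch.Size([1, 3136])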

During training I noticed that after one or two updates the Q-network outputs only tensors of zeros for any input. This goes on and the network does not seem to be learning, even after thousands of SGD updates with a minibatch size of 32. I realize that to properly learn an Atari game I will need many more updates, but if this is a bug I want to weed it out now rather than waste compute on longer training runs.
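
A quick way to confirm the collapse (a minimal diagnostic sketch; `obs` is just a placeholder for any batch of stacked frames, and `model` is the network defined in the Initialization section below):

# diagnostic sketch: check whether the outputs and the gradients of the
# final Linear layer have collapsed to zero
obs = torch.zeros(32, 4, 84, 84)   # placeholder batch of stacked frames
out = model(obs)
print("output abs max:", out.abs().max().item())

loss = (out - 1.0).pow(2).mean()   # dummy loss, just to get gradients
model.zero_grad()
loss.backward()
final_linear = model[-2]           # last Linear layer (right before the trailing ReLU)
print("grad abs max:", final_linear.weight.grad.abs().max().item())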

To sanity check my architecture I tried it on a toy problem: I’m generating random (4, 84, 84) tensors and on some of them I’m replacing every other row in the first channel with a row of 1s. The initialization code is similar to how I am initializing my real model.

Initialization

import numpy as np
import torch
from torch.nn import Sequential, ReLU, Linear, Conv2d, Flatten, MSELoss

model = Sequential()

# input has shape (4, 84, 84)
model.append(Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4))
model.append(ReLU())
model.append(Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2))
model.append(ReLU())
model.append(Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1))
model.append(ReLU())
model.append(Flatten(start_dim=1))
model.append(Linear(in_features=3136, out_features=512))
model.append(ReLU())
model.append(Linear(in_features=512, out_features=4))
model.append(ReLU())
# output has shape (4,)

optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)

Training and Evaluation (Toy Problem)

# creates training data which is partially random (4, 84, 84) tensors and tensors
# that have been modified by adding rows of 1s. The model should learn to output
# [0., 1., 0., 0.] for modified tensors and [0., 0., 0., 0.] otherwise
def create_observations_and_targets(n):
    observations = torch.zeros((n, 4, 84, 84))
    ground_truths = np.random.randint(0, 2, n)

    for i, ground_truth in enumerate(ground_truths):
        if ground_truth == 1:
            for j in range(0, 84, 2):
                observations[i][0][j] = torch.ones(84)

    targets = torch.zeros((n, 4))
    for i, ground_truth in enumerate(ground_truths):
        if ground_truth == 1:
            targets[i] = torch.tensor([0., 1., 0., 0.])

    return observations, targets
# do 500 updates with batches of size 32
for _ in range(500):
    training_observations, training_targets = create_observations_and_targets(32)

    outputs = model(training_observations)

    loss = MSELoss(reduction="none")(outputs, training_targets).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

evaluation_observations, evaluation_targets = create_observations_and_targets(32)
outputs = model(evaluation_observations)
evaluation_loss = MSELoss(reduction="mean")(outputs, evaluation_targets)

When I run the training and evaluation code above, in about half of all cases the model learns successfully and the evaluation loss ends up at zero. In the remaining cases, the model starts outputting [0, 0, 0, 0] for every input and does not seem to learn. Is this a common occurrence when training convolutional networks, and can it be fixed by, e.g., increasing the number of training steps (in the toy example, going from 50 to 500 updates did not seem to help)? Or is it more likely a problem with the architecture or how it is trained? If so, what can be done to reduce this behavior?

Basically, I am wondering whether, when I train this architecture on an Atari game like Breakout, I can expect it to eventually improve, or whether I would potentially be wasting dozens of hours of training time.

Could you check if removing the last nn.ReLU helps?
Negative outputs would be clipped via max(0, x), which results in a zero gradient for those entries:

import torch
import torch.nn.functional as F

x = torch.tensor([-1., -2., -3.], requires_grad=True)
out = F.relu(x)
out.mean().backward()
print(x.grad)
# tensor([0., 0., 0.])
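
If you want to keep your Sequential setup as-is, a minimal sketch of the change would be to either drop that last ReLU() when building the model or slice it off afterwards (indices here assume your layer order above):

# sketch: drop the trailing ReLU so the last Linear layer can output
# raw (possibly negative) Q-value estimates
model = model[:-1]   # everything up to and including Linear(512, 4)
print(model[-1])     # Linear(in_features=512, out_features=4, bias=True)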

Yes, removing the last ReLU activation helped! Thanks a lot!