Troubleshooting dead ReLUs (or something else)

I’m implementing the Chinese handwriting recognition model described in this paper. If I purposely overfit a single batch of 2 or 4 samples, the training loss does drop to 0 as expected. However, when I train on 40,000 samples (covering 200 classes) in batches of 4 and validate on about 8,000 samples, the training and validation losses hardly budge from epoch to epoch.
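
For context, a stripped-down sketch of the kind of training loop I'm running looks like this (train_one_epoch is just an illustrative name for this post; the optimizer, learning rate, and data loaders are placeholders rather than my exact settings):

import torch.nn as nn

def train_one_epoch(model, loader, optimizer, device):
    # Plain supervised loop; NLLLoss pairs with the model's log_softmax output.
    model.train()
    criterion = nn.NLLLoss()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)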

A quick web search suggests that this is a symptom of “dead ReLUs.” I doubt the ReLU activations themselves are the problem, though, since someone else implemented a slightly simpler model from the same paper in TensorFlow and got the expected results. I must be doing something wrong.
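
To rule that out rather than guess, here is the kind of check I could run: register a forward hook on each Conv2d and measure what fraction of the post-ReLU activations are zero (relu_zero_fractions is just a throwaway helper name for this sketch).

import torch
import torch.nn as nn
import torch.nn.functional as F

def relu_zero_fractions(model, batch):
    # For each Conv2d, record the fraction of post-ReLU activations that are zero.
    # Fractions staying near 1.0 across many batches would indicate dead units.
    fractions = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            fractions[name] = (F.relu(output) == 0).float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            handles.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(batch)

    for h in handles:
        h.remove()
    return fractions

If those fractions hover near 1.0 for every batch, the units really are dead; if not, the problem is presumably elsewhere.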

Could you tell me what’s wrong with my model below or, if not, suggest a way to troubleshoot, please? Thank you.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    # M7-1 architecture from the Zhang paper
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv5 = nn.Conv2d(512, 512, 3, padding=1)
        self.fc1 = nn.Linear(512 * 4 * 4, 1024)  # 4 = 64 / 2^4: four 2x2 max pools reduce a 64x64 input to 4x4
        self.fc2 = nn.Linear(1024, 200)

    def forward(self, X):
        # Spatial sizes in the comments assume 64x64 inputs, per the fc1 comment above.
        X = F.max_pool2d(F.relu(self.conv1(X)), (2, 2), (2, 2))  # -> 32x32
        X = F.max_pool2d(F.relu(self.conv2(X)), (2, 2), (2, 2))  # -> 16x16
        X = F.max_pool2d(F.relu(self.conv3(X)), (2, 2), (2, 2))  # -> 8x8
        X = F.relu(self.conv4(X))                                 # no pooling after conv4
        X = F.max_pool2d(F.relu(self.conv5(X)), (2, 2), (2, 2))  # -> 4x4

        X = F.dropout(torch.flatten(X, start_dim=1))  # functional dropout, default p=0.5
        X = F.dropout(F.relu(self.fc1(X)))
        X = self.fc2(X)
        return F.log_softmax(X, dim=1)  # log-probabilities

model = CNN()
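
As a sanity check on the shapes, a forward pass on a dummy batch gives the expected output size (assuming 64x64 grayscale inputs, which is what the fc1 comment implies):

dummy = torch.randn(4, 1, 64, 64)  # batch of 4 single-channel 64x64 images
out = model(dummy)
print(out.shape)  # torch.Size([4, 200])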