Does anyone know a convenient and easy way to tell whether your NN is suffering from dead ReLUs in PyTorch?
By dead ReLUs, do you mean a negative input and thus a zero output?
If so, you could count the zeros in the output activation.
Here is a toy example for a simple model:
```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU()
)

x = torch.randn(10, 10)
output = model(x)

# fraction of zero activations per sample
(output == 0).sum(1).float() / output.size(1)
```
For more complex models, you could use a forward hook to get the intermediate activations.
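Here is a minimal sketch of that approach; the model architecture and the `dead_fractions` dict are just made up for illustration. A forward hook is registered on every `nn.ReLU` module, and each hook records the fraction of zeros in that layer's output:

```python
import torch
import torch.nn as nn

# Hypothetical model with several ReLU layers
model = nn.Sequential(
    nn.Linear(10, 20), nn.ReLU(),
    nn.Linear(20, 20), nn.ReLU(),
    nn.Linear(20, 10),
)

# maps module name -> fraction of zero activations in its output
dead_fractions = {}

def make_hook(name):
    def hook(module, inputs, output):
        dead_fractions[name] = (output == 0).float().mean().item()
    return hook

# register a hook on every ReLU in the model
for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

x = torch.randn(64, 10)
model(x)
print(dead_fractions)
```

Note this measures zeros for one batch; a unit is only truly "dead" if it stays at zero across the whole dataset, so you may want to accumulate these statistics over many batches (or take the minimum per unit) before drawing conclusions.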