Does anyone know a convenient and easy way to check whether your NN is suffering from dead ReLUs in PyTorch?
By dead ReLUs, do you mean units that receive a negative input and thus output zero?
If so, you could count the zeros in the output activations.
Here is a toy example for a simple model:
```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU()
)
x = torch.randn(10, 10)
output = model(x)
# fraction of zero activations per sample
(output == 0).sum(1).float() / output.size(1)
```
For more complex models, you could use a forward hook to get the intermediate activations, as in the sketch below.
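Here is a minimal sketch of that approach, assuming a small two-layer toy model; the model layout and the `zero_fractions` dict are just illustrative, so adapt them to your own network:

```python
import torch
import torch.nn as nn

# Hypothetical model with two ReLU layers
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
    nn.ReLU()
)

zero_fractions = {}

def make_hook(name):
    # Record the fraction of zero activations produced by this module
    def hook(module, inp, out):
        zero_fractions[name] = (out == 0).float().mean().item()
    return hook

# Register a forward hook on every ReLU in the model
for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

x = torch.randn(10, 10)
output = model(x)
print(zero_fractions)  # fraction of zeros per ReLU layer for this batch
```

If a particular ReLU shows a fraction close to 1.0 across many batches, its units are likely dead for your data.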