Checking Whether a Network Is Dead in PyTorch?

Does anyone know a convenient and easy way to check whether your NN is suffering from dead ReLUs in PyTorch?

By dead ReLUs, do you mean units that receive a negative input and thus output zero?
If so, you could count the zeros in the output activation.
Here is a toy example for a simple model:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU()
)

x = torch.randn(10, 10)
output = model(x)
# Fraction of zero activations per sample
(output == 0).sum(1).float() / output.size(1)

For more complex models, you could use a forward hook to capture the intermediate activations.
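As a minimal sketch of that approach (the model, the `zero_fractions` dict, and the hook names are made up for this example), you could register a forward hook on each nn.ReLU and record the fraction of zero activations per layer:

import torch
import torch.nn as nn

# Illustrative model; swap in your own
net = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10),
    nn.ReLU()
)

zero_fractions = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Fraction of activations that are exactly zero for this layer
        zero_fractions[name] = (output == 0).float().mean().item()
    return hook

# Attach a hook to every ReLU module
for name, module in net.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

x = torch.randn(32, 10)
net(x)
print(zero_fractions)  # e.g. {'1': 0.48, '3': 0.51}

If a layer consistently reports a fraction close to 1.0 across many batches, its units are likely dead rather than just occasionally inactive.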
