I want to implement neuron splitting and don’t know how to approach this in PyTorch.

I’m working on an RL problem and observe that the number of dead ReLU neurons (i.e. neurons which are zero forever) in my model constantly increases.
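
For what it’s worth, this is roughly how such dead units can be counted. It is only a minimal sketch, assuming the ReLU output has shape (batch, features) and treating “dead” as “never positive on any input seen so far”. `DeadUnitTracker` and `dead_indices` are hypothetical names for illustration, not part of PyTorch.

```python
import torch
import torch.nn as nn


class DeadUnitTracker:
    """Records which units of a ReLU output have ever been positive (hypothetical helper)."""

    def __init__(self, relu: nn.Module) -> None:
        self.ever_active = None
        # The forward hook is called after every forward pass of `relu`.
        self.handle = relu.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Assumption: output is (batch, features); reduce over the batch dimension.
        active = (output.detach() > 0).any(dim=0)
        if self.ever_active is None:
            self.ever_active = active
        else:
            self.ever_active |= active

    def dead_indices(self):
        # Units that were never positive on any batch seen so far.
        return torch.nonzero(~self.ever_active).flatten().tolist()


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
with torch.no_grad():
    # Force unit 1 to be dead: zero weights and a negative bias.
    net[0].weight[1].zero_()
    net[0].bias[1] = -1.0

tracker = DeadUnitTracker(net[1])
with torch.no_grad():
    for _ in range(5):
        net(torch.randn(64, 4))

dead = tracker.dead_indices()
```

A unit that is silent on a few batches is not necessarily dead forever, so in practice the statistics would have to be accumulated over a long window.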

This is at least partially related to the fact that the distribution of the training data changes over time as behavior becomes more complex.

I’d like to try replacing dead neurons with copies of “alive” ones:

If neuron A is alive and neuron B is dead (forever zero), I can copy the weights and bias of A to B and split A’s outgoing weights in the next layer equally between A and B, i.e. Anew = Bnew = 0.5 * A.

As I understand it, this won’t change the network’s output at all.
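
To make the idea concrete, here is a minimal sketch for the trivial case only: two plain `nn.Linear` layers with a ReLU in between. `split_neuron` is a hypothetical helper name, and the check at the end assumes B really was dead (its output was zero on the test inputs) before the split.

```python
import torch
import torch.nn as nn


def split_neuron(fc1: nn.Linear, fc2: nn.Linear, alive: int, dead: int) -> None:
    """Copy unit `alive` of fc1 into unit `dead` and share its outgoing weights in fc2."""
    with torch.no_grad():
        # Incoming side: B becomes an exact copy of A.
        fc1.weight[dead] = fc1.weight[alive]
        fc1.bias[dead] = fc1.bias[alive]
        # Outgoing side: A's column is divided equally, Anew = Bnew = 0.5 * A.
        half = 0.5 * fc2.weight[:, alive]
        fc2.weight[:, alive] = half
        fc2.weight[:, dead] = half


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 3), nn.Linear(3, 2)
with torch.no_grad():
    # Make unit 2 dead: zero weights and a negative bias give a ReLU output of 0.
    fc1.weight[2].zero_()
    fc1.bias[2] = -1.0


def forward(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return fc2(torch.relu(fc1(x)))


x = torch.randn(8, 4)
before = forward(x)
split_neuron(fc1, fc2, alive=0, dead=2)
after = forward(x)

# Largest change in the output caused by the split (floating-point rounding only).
max_diff = (before - after).abs().max().item()
```

The sketch hard-codes which layer is “next” and which column belongs to which unit, and that is exactly the part I don’t know how to obtain for a complicated model.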

To make the split neurons drift apart over time I’m going to add some noise, but that is a different story.

So the question is: what is the most correct way to implement this in PyTorch?

The problem is how to tell which layer is the next one and which indices in it correspond to indices in the current one.

The problem may seem trivial if the model consists of a couple of fc layers, but I’ve got a rather complicated model with residual connections, splits and concatenations…

PyTorch builds a computation graph; is it possible to extract this information from it?
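
For reference, the autograd graph can at least be walked from an output tensor through `grad_fn.next_functions`. The sketch below, with the hypothetical helper `walk_autograd`, only collects the backward node types and the leaf tensors they reach. It assumes the leaves appear as `AccumulateGrad` nodes exposing a `variable` attribute, and the exact node names depend on the PyTorch version.

```python
import torch
import torch.nn as nn


def walk_autograd(output: torch.Tensor):
    """Return (node type names, leaf tensors) reachable from output.grad_fn."""
    names, leaves, seen = [], [], set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        # Inputs that don't require grad show up as None.
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # AccumulateGrad nodes hold the leaf tensor, e.g. a parameter.
        if hasattr(fn, "variable"):
            leaves.append(fn.variable)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names, leaves


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
out = model(torch.randn(8, 4))
names, leaves = walk_autograd(out)

# Every parameter of the model should be reachable as a leaf of the graph.
found_all_params = {t.data_ptr() for t in leaves} == {
    p.data_ptr() for p in model.parameters()
}
```

This shows which parameters feed which backward operations, but the nodes are operations rather than modules, and nothing here says how indices map through a concatenation or a split, which is the information I actually need.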

I would really like to avoid building a workaround here, as my model code would become extremely fragile.

Roman