Splitting neurons, best practice

I want to implement neuron splitting and don’t know how to approach this in PyTorch.
I’m working on an RL problem and observe that the number of dead ReLU neurons (i.e. neurons whose output is zero forever) in my model keeps increasing.
This is at least partially related to the fact that the distribution of the training data changes over time as behavior becomes more complex.
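Before replacing anything, the dead units have to be found. Below is a minimal sketch, assuming the activation is an `nn.ReLU` module (not `F.relu`) called once per forward pass, with features on the last dimension. That holds for fully connected layers; for Conv2d outputs the channel dimension is 1. `find_dead_units` is a hypothetical helper, not a PyTorch API. It registers a forward hook and marks a unit as dead if its output was zero for every sample in the batches it saw.

```python
import torch
import torch.nn as nn


def find_dead_units(model: nn.Module, relu: nn.ReLU, batches) -> torch.Tensor:
    """Bool mask over the ReLU's last dimension: True = output was 0 for every sample seen."""
    ever_active = None

    def hook(module, inputs, output):
        nonlocal ever_active
        # Collapse every dimension except the feature (last) one, then ask
        # whether each feature was positive for at least one sample.
        active = (output.detach().reshape(-1, output.shape[-1]) > 0).any(dim=0)
        ever_active = active if ever_active is None else ever_active | active

    handle = relu.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for x in batches:
                model(x)
    finally:
        handle.remove()  # do not leave the hook attached to the model
    return ~ever_active


# Usage: force hidden unit 3 to be dead with a very negative bias.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with torch.no_grad():
    model[0].bias[3] = -100.0
dead = find_dead_units(model, model[1], [torch.randn(16, 4) for _ in range(4)])
print(dead.nonzero().flatten().tolist())
```

"Dead" here only means "dead on the data passed in". A unit that is merely inactive on this sample of the data may still fire on other inputs.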

I’d like to try replacing dead neurons with copies of “alive” ones:
If neuron A is alive and B is dead (always zero), I can copy A’s incoming weights/bias to B and split A’s outgoing weights in the next layer between A and B, e.g. A_new = B_new = 0.5 * A.
As I understand it, this won’t change the net’s output at all.
To make the split neurons drift apart over time I’m going to add some noise, but that is a different story.
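For a plain `Linear -> ReLU -> Linear` chain, the split described above can be sketched as follows. `split_unit` is a hypothetical helper. Incoming weights are rows of the first layer and outgoing weights are columns of the second. The sketch assumes B is exactly zero on every input, which is what makes the output unchanged.

```python
import torch
import torch.nn as nn


def split_unit(fc_in: nn.Linear, fc_out: nn.Linear, alive: int, dead: int) -> None:
    """Replace hidden unit `dead` by a copy of unit `alive` in fc_in -> ReLU -> fc_out."""
    with torch.no_grad():
        # Incoming side (rows of fc_in): B becomes an exact copy of A.
        fc_in.weight[dead] = fc_in.weight[alive]
        fc_in.bias[dead] = fc_in.bias[alive]
        # Outgoing side (columns of fc_out): A and B each get half of A's column,
        # so 0.5*w*a + 0.5*w*a == w*a and the next layer sees the same sum.
        fc_out.weight[:, alive] *= 0.5
        fc_out.weight[:, dead] = fc_out.weight[:, alive]


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 8), nn.Linear(8, 3)
net = nn.Sequential(fc1, nn.ReLU(), fc2)
with torch.no_grad():
    fc1.bias[5] = -100.0  # unit 5 is dead on these inputs
x = torch.randn(32, 4)
before = net(x).detach()
split_unit(fc1, fc2, alive=2, dead=5)
after = net(x).detach()
print(torch.allclose(before, after, atol=1e-6))  # True
```

If B is only rarely active rather than truly dead, its old contribution `W_next[:, B] * b` is lost and the output will change.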

So the question is: what is the most correct way to implement this in PyTorch?
The difficulty is telling which layer comes next and which indices in it correspond to the indices in the current one.
The problem may seem trivial if the model consists of a couple of fc layers, but I’ve got a rather complicated model with residual connections, splits and concatenations…
PyTorch builds a computation graph; is it possible to extract this information from it?
I would really like not to build a workaround here, as my model (code) would become extremely fragile.



I am not sure I follow what you want to do here.

As I understand it, this won’t change the net’s output at all.

If you change the value of the activations, it will change the output of the next layer, because it is very unlikely that the next layer treats A and B exactly the same way.
If you change the value of the weights, the output will also change, as it is very unlikely that the inputs to A and B are the same.

But if what you want is to update the weights of a single layer (let’s say Linear), you can simply do:

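import torch

# `layer` is an nn.Linear; each row of weight/bias belongs to one output neuron
with torch.no_grad():
    layer.weight[alive_idx] /= 2
    layer.weight[dead_idx] = layer.weight[alive_idx]

    layer.bias[alive_idx] /= 2
    layer.bias[dead_idx] = layer.bias[alive_idx]  # copy into the dead neuron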

Thank you for the reply!

Here is my very simple diagram of what I want to achieve:

I understand that identical subnets will stay identical forever, so I’m going to add some noise to make them drift apart over time.
Also, the new weights don’t have to be 50% of the original; they only need to sum to the original.
But again this is a different story.
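The “only need to sum to the original” variant can be sketched with a random fraction per outgoing weight. `split_unit_random`, `noise_std` and the uniform choice of fractions are illustrative assumptions. With `noise_std=0` the output is preserved, because `p*w*a + (1-p)*w*a == w*a` up to rounding. Any non-zero noise on the incoming side changes the output slightly.

```python
import torch
import torch.nn as nn


def split_unit_random(fc_in, fc_out, alive, dead, noise_std=0.0, generator=None):
    """Copy unit `alive` into `dead`, splitting each outgoing weight with a random fraction."""
    with torch.no_grad():
        # Incoming side: copy A into B, optionally perturbed (noise_std=0 keeps the output exact).
        noise = noise_std * torch.randn(fc_in.in_features, generator=generator)
        fc_in.weight[dead] = fc_in.weight[alive] + noise
        fc_in.bias[dead] = fc_in.bias[alive]
        # Outgoing side: per-output fraction p in [0, 1); A keeps p*w, B gets (1-p)*w.
        w = fc_out.weight[:, alive].clone()
        p = torch.rand(w.shape, generator=generator)
        fc_out.weight[:, alive] = p * w
        fc_out.weight[:, dead] = (1 - p) * w


torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 8), nn.Linear(8, 3)
net = nn.Sequential(fc1, nn.ReLU(), fc2)
with torch.no_grad():
    fc1.bias[5] = -100.0  # unit 5 is dead on these inputs
x = torch.randn(32, 4)
w_old = fc2.weight[:, 2].clone()
before = net(x).detach()
split_unit_random(fc1, fc2, alive=2, dead=5)
after = net(x).detach()
```

Because A and B now have different outgoing weights, they generally receive different gradients on their (identical) incoming weights. Unequal fractions alone therefore already break the symmetry, and the noise on the incoming side is optional.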

My question is: is it possible, somehow automatically (even if hacky), to detect the indices in the next layer that correspond to given dead/alive neurons in the current layer?
For a simple feed-forward net like the one in the diagram it’s trivial, but in my model I split/concatenate/sum tensors after the ReLU, so it’s rather problematic to keep that knowledge “externally”.
I know PyTorch builds a computation graph as operations are applied to tensors; is it possible to extract this kind of information from the graph?
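One partial answer is `torch.fx`, which can trace a model into an explicit graph whose nodes list their consumers (`node.users`). The sketch below, with a hypothetical toy `Block` and helper `direct_users`, only recovers *which* layers/ops read a given activation. It does not give the index correspondence through `cat`/`split`/sums; that still has to be derived from shapes (e.g. `h` occupies features 0..7 of `y` here). `symbolic_trace` also cannot handle data-dependent control flow.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Block(nn.Module):
    """Toy model: the ReLU output feeds a Linear layer and a concatenation."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(8, 4)
        self.head = nn.Linear(12, 2)

    def forward(self, x):
        h = self.act(self.fc1(x))
        y = torch.cat([h, self.fc2(h) + x], dim=1)  # 8 + 4 = 12 features
        return self.head(y)


def direct_users(gm, module_name: str):
    """Return (op, target) for every graph node that consumes `module_name`'s output."""
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == module_name:
            return [(user.op, user.target) for user in node.users]
    raise KeyError(module_name)


gm = symbolic_trace(Block())
for op, target in direct_users(gm, "act"):
    print(op, target)
```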



There is definitely no automatic way to do that.
And it’s not even always true that there exists a change to the downstream layer that gives you what you want. In particular, you mention “sum”: that could lead to multiple nodes being used with the same weights, so you cannot modify these weights for some nodes and not others.
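A toy illustration of the “sum” problem, with all names and sizes hypothetical: two ReLU branches are added before a Linear head, so column `j` of the head multiplies unit `j` of *both* branches. Applying the chain-style split to branch `a` is correct for branch `a`, but it also rescales branch `b`’s units 1 and 4, so the output changes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc_a, fc_b, head = nn.Linear(4, 6), nn.Linear(4, 6), nn.Linear(6, 2)


def net(x):
    # Feature j of the sum mixes unit j of branch a and unit j of branch b,
    # so column j of `head` multiplies both of them.
    return head(torch.relu(fc_a(x)) + torch.relu(fc_b(x)))


with torch.no_grad():
    fc_a.bias[4] = -100.0  # unit 4 of branch a is dead
    fc_b.bias[[1, 4]] = 10.0  # units 1 and 4 of branch b are always active here

x = torch.randn(32, 4)
before = net(x).detach()

with torch.no_grad():
    # Same split as for a plain Linear -> ReLU -> Linear chain, applied to branch a.
    fc_a.weight[4] = fc_a.weight[1]
    fc_a.bias[4] = fc_a.bias[1]
    head.weight[:, 1] *= 0.5
    head.weight[:, 4] = head.weight[:, 1]

after = net(x).detach()
# Branch a is handled correctly, but halving column 1 and overwriting column 4
# also rescales branch b's units 1 and 4, so the output changes.
print(torch.allclose(before, after, atol=1e-6))  # False
```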


Thanks for clarification!

I’m actually working on an algorithm for this very thing. Currently, the algorithm only works on Linear and Conv2D layers; GRU is in the works. It will also be updated soon to support removing targeted neurons.
Neurons are targeted for splitting by averaging the absolute value of their gradients over an epoch; any neuron above the cutoff gets split.
But what you are referring to is not neuron splitting but neuron removal. When that update is released, it will target neurons whose gradient movement is below the threshold you set.
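One possible reading of the targeting rule above, sketched for a single `nn.Linear`. This is not the author’s code. `GradTracker`, `split_cutoff` and `remove_cutoff` are hypothetical names, and reducing a neuron’s gradient to one number by averaging |grad| over its incoming weights is an assumption. A dead ReLU unit gets exactly zero weight gradient, so it always ends up below any positive removal cutoff.

```python
import torch
import torch.nn as nn


class GradTracker:
    """Running mean over an epoch of |dL/dW| per output neuron of one nn.Linear."""

    def __init__(self, layer: nn.Linear):
        self.layer = layer
        self.total = torch.zeros(layer.out_features)
        self.steps = 0

    def update(self) -> None:
        # Call after loss.backward(); averages |grad| over each neuron's incoming weights.
        self.total += self.layer.weight.grad.detach().abs().mean(dim=1)
        self.steps += 1

    def select(self, split_cutoff: float, remove_cutoff: float):
        score = self.total / max(self.steps, 1)
        to_split = torch.nonzero(score > split_cutoff).flatten()
        to_remove = torch.nonzero(score < remove_cutoff).flatten()
        return score, to_split, to_remove


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
with torch.no_grad():
    model[0].bias[3] = -100.0  # dead unit: its weight gradient is exactly zero
opt = torch.optim.SGD(model.parameters(), lr=0.01)
tracker = GradTracker(model[0])
for _ in range(10):  # stand-in for one epoch
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    tracker.update()
    opt.step()
score, to_split, to_remove = tracker.select(split_cutoff=1e-2, remove_cutoff=1e-8)
print(score, to_split.tolist(), to_remove.tolist())
```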

Here is the GitHub repository, with working examples and tests showing how this performs compared to not splitting.

Please note that you can use this code freely for personal or research projects. For other purposes, such as commercial use, please contact me about licensing and support/development.


Sorry for the delay. I have updated the function so it can remove neurons deterministically via cutoffrem. I’ve also provided some simple setup steps and a script that can be imported; the link is already provided above. It works for Linear and Conv2D layers for now. RNN layers like GRU or LSTM cannot be handled this way because of the hidden state (I tried several approaches, but all of them performed poorly). Please let me know if you find any bugs.
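For readers who want to see what removing a unit means mechanically, here is a sketch for a plain `Linear -> ReLU -> Linear` pair. It is not the `cutoffrem` implementation referred to above, and `remove_units` is a hypothetical helper. It rebuilds both layers without the dropped rows (producing side) and the matching columns (consuming side). If the dropped units are dead on the data, the output is unchanged.

```python
import torch
import torch.nn as nn


def remove_units(fc_in: nn.Linear, fc_out: nn.Linear, drop):
    """Return new (fc_in, fc_out) without the hidden units listed in `drop`."""
    drop = set(drop)
    keep = [i for i in range(fc_in.out_features) if i not in drop]
    new_in = nn.Linear(fc_in.in_features, len(keep), bias=fc_in.bias is not None)
    new_out = nn.Linear(len(keep), fc_out.out_features, bias=fc_out.bias is not None)
    with torch.no_grad():
        new_in.weight.copy_(fc_in.weight[keep])  # keep rows of the producing layer
        if fc_in.bias is not None:
            new_in.bias.copy_(fc_in.bias[keep])
        new_out.weight.copy_(fc_out.weight[:, keep])  # keep matching columns downstream
        if fc_out.bias is not None:
            new_out.bias.copy_(fc_out.bias)
    return new_in, new_out


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
with torch.no_grad():
    net[0].bias[3] = -100.0  # unit 3 is dead on these inputs
x = torch.randn(32, 4)
before = net(x).detach()
net[0], net[2] = remove_units(net[0], net[2], drop=[3])
after = net(x).detach()
print(net[0], torch.allclose(before, after, atol=1e-6))
```

The new layers have new parameter tensors, so an existing optimizer (and any state it keeps, such as momentum buffers) has to be recreated after the swap.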