Is there a simple way to keep zero gradients out of the average when some inputs in a batch produce zero gradients?

Hi, experts. I have a model where some layers are only used for some datapoints.
For example, let’s say only two out of 128 datapoints in a batch (bs=128) use that particular layer.
My problem is that I want to train this layer, but its gradient becomes tiny: 126 of the 128 inputs in the batch produce zero gradient for that layer, and the gradient is averaged across all 128 inputs.

I want to know if there’s a way to ignore the zero gradients from the 126 irrelevant inputs in the batch and average the gradient over only the two relevant samples for these particular layers, while still averaging across all 128 samples for the other layers.

Thank you for the help.

This is not easily possible, but it should be possible to use some normalization technique to keep the gradients from vanishing. For adaptive optimizers like Adam or LAMB, for example, it would not matter that much unless the number of “nonzero gradient examples” varies starkly from one batch to the next.
At which point do you know if you will get zero gradients for an example?
If you can tell from the (unreduced) loss, the easiest option might be to scale the loss appropriately. Otherwise, you could find the first backpropagation edge where you can tell and scale the gradient there.
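
To make the scaling idea concrete, here is a minimal sketch of one variant, not the only way to do it. It assumes you already have a boolean mask telling you which examples go through the layer, and it rescales only that layer's parameter gradients after `backward()`, so the other layers keep their usual average over the full batch. The names `shared`, `rare`, `head` and `uses_rare`, as well as the sizes, are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, dim = 128, 8
shared = nn.Linear(dim, dim)  # hypothetical layer used by every example
rare = nn.Linear(dim, dim)    # hypothetical layer used by only a few examples
head = nn.Linear(dim, 1)

x = torch.randn(batch_size, dim)
y = torch.randn(batch_size, 1)

# Assumed routing mask: True where an example goes through `rare`.
uses_rare = torch.zeros(batch_size, dtype=torch.bool)
uses_rare[:2] = True

h = torch.relu(shared(x))
# Rows outside the mask bypass `rare`, so they send zero gradient to it.
h = torch.where(uses_rare.unsqueeze(1), rare(h), h)

loss = nn.functional.mse_loss(head(h), y)  # mean over all 128 examples
loss.backward()

# Turn the mean over the whole batch into a mean over the examples that
# actually used the layer, for the parameters of `rare` only.
num_used = int(uses_rare.sum())
if num_used > 0:
    scale = batch_size / num_used
    for p in rare.parameters():
        p.grad.mul_(scale)
```

Scaling the per-example loss instead would also change the gradients of every other layer those examples pass through, which is why this sketch touches only the gradients of `rare`. If the number of relevant examples changes a lot between batches, the scale factor changes with it, so it may be worth clamping.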

Of course, the other thing to keep in mind is that if you’re only getting a training signal from 2% of your data, you are not using the data very efficiently, which may be a problem in its own right.

Best regards