I could use some expertise on the following problem… thanks a lot, your help is much appreciated!
I am dealing with a dataset for which:
- Samples have 10 continuous features: x = (x1, x2, …, x10)
- Certain samples are often missing the continuous feature x1, which ranges from 0 to 100.
- When x1 is missing, I give it a default null value, x1 = 0, and set a categorical indicator feature x11 = 0.
- When x1 is not missing, it lies within the range 0 to 100, and I set x11 = 1 to indicate that x1 is present.
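The encoding above can be sketched as follows; `encode_missing_x1` is a hypothetical helper name, and the masking convention is exactly the one described (impute 0, append an indicator column):

```python
import torch

def encode_missing_x1(x, missing_mask):
    """x: (N, 10) float tensor; missing_mask: (N,) bool, True where x1 is missing."""
    x = x.clone()
    x[missing_mask, 0] = 0.0                    # default null value for missing x1
    x11 = (~missing_mask).float().unsqueeze(1)  # indicator: 1 = present, 0 = missing
    return torch.cat([x, x11], dim=1)           # (N, 11)

# Example: 4 samples in the 0-100 range, x1 missing for the last two.
x = torch.rand(4, 10) * 100
mask = torch.tensor([False, False, True, True])
xe = encode_missing_x1(x, mask)
```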
When I try to have a shallow torch MLP learn the identity mapping MLP(x) → x1 (I make sure to only penalize the samples for which x1 is present), learning usually goes smoothly, EXCEPT when I use batch normalization.
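For concreteness, here is a minimal sketch of that setup (the architecture, layer sizes, and hyperparameters are illustrative assumptions, not the exact ones used): a shallow MLP trained to recover x1, with the loss masked so only samples with x11 = 1 contribute:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_mlp(use_bn=False):
    # Shallow MLP; optionally insert BatchNorm1d after the hidden layer.
    layers = [nn.Linear(11, 32)]
    if use_bn:
        layers.append(nn.BatchNorm1d(32))
    layers += [nn.ReLU(), nn.Linear(32, 1)]
    return nn.Sequential(*layers)

def masked_mse(pred, target, present):
    # Only penalize samples whose x1 is actually present (x11 == 1).
    diff = (pred.squeeze(1) - target) ** 2
    return (diff * present).sum() / present.sum().clamp(min=1)

# Toy data following the encoding above: x1 imputed to 0 when missing,
# with the presence indicator x11 in the last column.
N = 256
x = torch.rand(N, 10) * 100
present = (torch.rand(N) > 0.5).float()
x[present == 0, 0] = 0.0
inp = torch.cat([x, present.unsqueeze(1)], dim=1)

model = make_mlp(use_bn=False)  # set use_bn=True to reproduce the failure mode
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(1000):
    opt.zero_grad()
    loss = masked_mse(model(inp), inp[:, 0], present)
    loss.backward()
    opt.step()
```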
I can’t remove the samples for which x11 = 0 because of future work with them (the problem is actually slightly more complex: I am dealing with a graph, and each of these samples is a node within it…)
My guess is that, since x1 is missing for many samples, batches have widely varying means and variances, so BN’s running mean and variance never stabilize, which ultimately prevents the MLP from learning the simple mapping above. I tried training on a fixed batch to test this hypothesis.
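The effect on the batch statistics can be seen directly. Since BatchNorm1d normalizes over the whole batch, the imputed zeros drag the mean of x1 down and distort its spread relative to the statistics of the genuinely observed values (a sketch, assuming the 0-100 range and ~50% missingness described above):

```python
import torch

torch.manual_seed(0)

# Compare the statistics BatchNorm sees (all samples, zeros included)
# with the statistics of the truly observed x1 values.
x1 = torch.rand(1000) * 100       # true x1 values, range 0-100
present = torch.rand(1000) > 0.5  # ~half the samples have x1
x1_imputed = torch.where(present, x1, torch.zeros_like(x1))

mean_all = x1_imputed.mean()      # what BN normalizes with
mean_obs = x1[present].mean()     # statistics of the real values
var_all = x1_imputed.var()
var_obs = x1[present].var()
```

Since the fraction of missing samples varies from batch to batch, these distorted statistics also fluctuate across batches, which is consistent with the running estimates failing to stabilize.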
Is there a way to keep the normalization while also keeping the samples with missing x1 in the training set? Should I encode the fact that x1 is missing in some other way?
For those who are interested in the bigger picture, see the EDIT below.
EDIT: I will be using [2111.12128] On the Unreasonable Effectiveness of Feature Propagation in Learning on Graphs with Missing Node Features to propagate information to the nodes with missing features for my problem… still, for the non-graph problem described above, I find it intriguing that batch norm applied to a whole batch breaks convergence when the loss only uses a subset of the samples (the ones with non-missing x1).