Normalization with small batches

Hi,
I have a linear layer whose input is 32 (large) features per example; for example, with a batch size of 2:

tensor([[ 1477.7378,  2719.6460,  1040.3798,  1215.5387,  3542.7241,   266.8930,
          5944.8560, -2105.2986,  4023.8545,  -932.9826,  1617.7998,  4028.0327,
         -2225.9929, -1581.6738, -3424.4250, -1329.6610,  5483.1172,  5975.4663,
          4483.0156,  -373.3383,   482.2250,  -557.7921,  1385.3356,   196.1219,
          -375.9175,  1787.6195, -3516.9705,  -687.9647,  2663.0601,  -576.0363,
            86.2473,  5765.6387],
        [ 1744.0067,  3214.6289,  1228.2401,  1436.6631,  4191.4390,   318.6791,
          7028.2715, -2485.2910,  4753.7217, -1105.6101,  1920.9307,  4758.9941,
         -2634.7754, -1861.2491, -4049.8877, -1578.7089,  6478.5850,  7065.4326,
          5293.2305,  -442.4786,   566.5157,  -658.5790,  1643.0204,   237.8128,
          -450.6484,  2104.6516, -4151.9263,  -815.9917,  3144.3291,  -675.0918,
            94.2207,  6817.9795]])

Due to the large values, I'm trying to normalize the inputs. However, I'm training with a batch size of 4, which makes batch normalization very unstable (since it normalizes with the mean and variance of the current batch).
I tried using layer normalization, but it produces very similar outputs for different inputs; for example, for the values above:

tensor([[ 0.1271,  0.5954, -0.0378,  0.0283,  0.9057, -0.3294,  1.8113, -1.2237,
          1.0871, -0.7818,  0.1799,  1.0886, -1.2692, -1.0263, -1.7211, -0.9313,
          1.6372,  1.8229,  1.2602, -0.5708, -0.2482, -0.6403,  0.0923, -0.3561,
         -0.5717,  0.2440, -1.7560, -0.6894,  0.5740, -0.6472, -0.3975,  1.7437],
        [ 0.1264,  0.5955, -0.0382,  0.0283,  0.9070, -0.3283,  1.8119, -1.2227,
          1.0864, -0.7826,  0.1828,  1.0881, -1.2704, -1.0236, -1.7218, -0.9335,
          1.6366,  1.8238,  1.2585, -0.5711, -0.2492, -0.6400,  0.0942, -0.3541,
         -0.5737,  0.2414,  -1.7543, -0.6902,  0.5730, -0.6453, -0.3999,  1.7449]])

This makes it impossible to differentiate between the two examples.
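
For reference, this is roughly how I'm applying it (a minimal reproduction with only the first few features of the tensor above, using torch.nn.functional.layer_norm rather than my actual model code):

import torch
import torch.nn.functional as F

x = torch.tensor([[1477.7378, 2719.6460, 1040.3798],   # first few features of example 1
                  [1744.0067, 3214.6289, 1228.2401]])  # first few features of example 2

# Layer norm rescales each row to zero mean / unit variance across its features,
# so two rows that differ mostly by a common scale come out nearly identical.
normed = F.layer_norm(x, normalized_shape=(3,))
print(normed)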
Is there a normalization layer that suits this problem?
Maybe something that performs per-feature normalization (like batch normalization) but uses the running mean and variance?
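Something along these lines is what I'm imagining (just a rough, untested sketch; RunningNorm is a made-up module, not an existing PyTorch layer):

import torch
import torch.nn as nn

class RunningNorm(nn.Module):
    # Per-feature normalization that always normalizes with the running statistics,
    # so a tiny batch only nudges the estimates instead of driving the output.
    def __init__(self, num_features, momentum=0.01, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                # Exponential moving average of the per-feature batch statistics.
                self.running_mean.lerp_(x.mean(dim=0), self.momentum)
                self.running_var.lerp_(x.var(dim=0, unbiased=False), self.momentum)
        # Normalize with the (slowly updated) running statistics, not the batch statistics.
        return (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)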
Thanks!

Could you elaborate on why batch normalization makes training unstable when the batch size is 4?

Batch normalization uses the mean and variance of each feature computed over the batch. Four examples are not representative of the full dataset, so their statistics can be very far from the dataset's mean and variance.
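
To illustrate with a toy example (synthetic data, just to show how much 4-sample statistics fluctuate from batch to batch):

import torch

torch.manual_seed(0)
data = torch.randn(10_000, 32) * 1000.0          # toy dataset with large-valued features
full_mean = data.mean(dim=0)

for _ in range(3):
    idx = torch.randint(0, data.shape[0], (4,))  # a random batch of 4 examples
    batch_mean = data[idx].mean(dim=0)
    # The per-feature mean of a 4-sample batch is typically hundreds away from the
    # dataset mean here, so the statistics BatchNorm normalizes with jump between steps.
    print((batch_mean - full_mean).abs().max().item())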


And when you say “different inputs”, do you mean different inputs to layer norm, or different inputs to the network?

Different inputs to the layer norm (i.e., the outputs of the previous layer).