# Use BatchNorm directly on input

Hi, I was wondering whether it could be useful or harmful to apply batch normalization directly to the input of a network. Example of a simple network:

```python
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn0 = nn.BatchNorm1d(128)
        self.fc1 = nn.Linear(128, 4096)
        self.bn1 = nn.BatchNorm1d(4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.bn2 = nn.BatchNorm1d(4096)
        self.fc3 = nn.Linear(4096, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        normalized_input = self.bn0(x)
        h = self.relu(self.bn1(self.fc1(normalized_input)))
        h = self.relu(self.bn2(self.fc2(h)))
        h = self.fc3(h)
        return h
```

Suppose I have a huge training set whose mean and std are unknown (for whatever reason), so I cannot normalize the dataset beforehand. By applying batch norm even before feeding data to the first layer, I should be able to normalize my data anyway. It sounds plausible to me, but I have never seen a network built this way. What are the possible implications of such an approach?

Essentially, the answer is yes. At train time, the mean and std are estimated from each batch, so a larger batch size gives more stable estimates and makes this technique more effective. Given enough training, the layer's running statistics will converge to good estimates of the mean and std of your input distribution, and those running estimates are what get used at eval time.
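To illustrate, here is a small sketch (the layer size, stream statistics, and iteration count are made up for the demo) showing that a `BatchNorm1d` layer fed raw, unnormalized data during training accumulates running estimates of the input's mean and std:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # plays the role of bn0 above, on 4 features

# Simulated training stream with true mean 5 and true std 3 per feature.
bn.train()
for _ in range(200):
    batch = torch.randn(64, 4) * 3.0 + 5.0
    bn(batch)  # each forward pass updates running_mean / running_var

# The running estimates approach the true statistics.
est_mean = bn.running_mean        # ≈ 5 per feature
est_std = bn.running_var.sqrt()   # ≈ 3 per feature

# At eval time these running stats are used, independent of batch size.
bn.eval()
out = bn(torch.randn(2, 4) * 3.0 + 5.0)
```

Note that the quality of these estimates depends on the batch size and the momentum of the running-average update, which is one practical difference from normalizing with dataset-wide statistics computed up front.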

However, I don’t see many circumstances where you couldn’t compute the mean and std directly. Even if the dataset is massive, you can compute them exactly without loading everything into memory: one pass over the dataset to compute the mean, then a second pass to compute the std.
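The two-pass computation can be sketched as follows; the dataset and batch size here are placeholders, and in practice you would iterate over your own `DataLoader`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for a dataset too large to hold in memory all at once.
data = torch.randn(10_000, 128) * 2.0 + 1.0
loader = DataLoader(TensorDataset(data), batch_size=256)

n_features = 128

# Pass 1: accumulate per-feature sums to get the exact mean.
total = torch.zeros(n_features)
count = 0
for (batch,) in loader:
    total += batch.sum(dim=0)
    count += batch.shape[0]
mean = total / count

# Pass 2: accumulate squared deviations from the mean to get the std.
sq_diff = torch.zeros(n_features)
for (batch,) in loader:
    sq_diff += ((batch - mean) ** 2).sum(dim=0)
std = torch.sqrt(sq_diff / count)
```

The resulting `mean` and `std` can then be applied as a fixed normalization (e.g. `(x - mean) / std`) instead of relying on batch statistics at the input.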