How to do normalization over a specific dimension?

I know LayerNorm and InstanceNorm1d can normalize over the last dimension. If I want to normalize over the first dimension, I have to transpose the input tensor or compute the mean and standard deviation myself, which costs a lot of time and memory.

If I do it myself or transpose the input:

import torch
import torch.nn as nn

class Net1(nn.Module):
    # Normalize over dim 0 by computing the mean and variance manually.
    def __init__(self):
        super(Net1, self).__init__()

    def forward(self, x):
        return (x - x.mean(dim=0, keepdim=True)) / (x.var(dim=0, unbiased=False, keepdim=True) + 1e-5).sqrt()


class Net2(nn.Module):
    # Swap dims 0 and 2 so the target dim is last, apply layer_norm, swap back.
    def __init__(self):
        super(Net2, self).__init__()

    def forward(self, x):
        return nn.functional.layer_norm(x.transpose(0, 2), (1000,)).transpose(0, 2)


class Net3(nn.Module):
    # Swap dims 0 and 2, apply instance_norm over the (now) last dim, swap back.
    def __init__(self):
        super(Net3, self).__init__()

    def forward(self, x):
        return nn.functional.instance_norm(x.transpose(0, 2)).transpose(0, 2)

net1 = Net1()
net2 = Net2()
net3 = Net3()

x = torch.randn(1000, 100000, 10, requires_grad=True)

%timeit net1(x)   #4.7 s ± 32.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net2(x)   #2.71 s ± 46.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net3(x)   #3.33 s ± 82.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit net1(x).mean().backward()   #22.1 s ± 52.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net2(x).mean().backward()   #11.3 s ± 169 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net3(x).mean().backward()   #10.3 s ± 134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
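
As a quick sanity check (my addition, on a small tensor whose dim 0 is still 1000 to match the hard-coded (1000,) in Net2), the three dim-0 variants agree up to floating-point error:

x_small = torch.randn(1000, 8, 4)   # dim 0 must be 1000 because Net2 hard-codes the normalized shape
out1 = net1(x_small)
out2 = net2(x_small)
out3 = net3(x_small)
print(torch.allclose(out1, out2, atol=1e-5))   # expected: True
print(torch.allclose(out1, out3, atol=1e-5))   # expected: True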

But if I do the normalization over the last dimension, without any transpose, it takes about half the time:

y = torch.randn(10, 100000, 1000, requires_grad=True)

class Net5(nn.Module):
    # layer_norm directly over the last dimension (size 1000), no transpose needed.
    def __init__(self):
        super(Net5, self).__init__()

    def forward(self, x):
        return nn.functional.layer_norm(x, (1000,))


class Net6(nn.Module):
    # instance_norm directly over the last dimension, no transpose needed.
    def __init__(self):
        super(Net6, self).__init__()

    def forward(self, x):
        return nn.functional.instance_norm(x)

net5 = Net5()
net6 = Net6()

%timeit net5(y)   #1.24 s ± 22.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net6(y)   #1.73 s ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net5(y).mean().backward()   #4.07 s ± 28.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit net6(y).mean().backward()   #4.69 s ± 32.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Is there any way to do this efficiently over a specific dimension?

What they should do is add another argument to nn.LayerNorm, e.g. an axis/dim=-1 parameter, so that users can specify which axis to normalize over; this is especially useful when you don't know the shapes of the inputs a priori.
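
Until something like that exists, here is a minimal sketch of such a wrapper (DimLayerNorm is my own name, not a PyTorch API; it assumes the size of the normalized dimension is known up front so the affine parameters can be created). It just moves the chosen dimension to the end, calls F.layer_norm, and moves it back, so it only hides the transpose rather than avoiding its cost:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DimLayerNorm(nn.Module):
    # Hypothetical wrapper: LayerNorm over an arbitrary dimension `dim`.
    def __init__(self, num_features, dim=-1, eps=1e-5):
        super(DimLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        x = x.movedim(self.dim, -1)                                    # put the target dim last
        x = F.layer_norm(x, x.shape[-1:], self.weight, self.bias, self.eps)
        return x.movedim(-1, self.dim)                                 # restore the original layout

norm = DimLayerNorm(1000, dim=0)
out = norm(torch.randn(1000, 32, 10))   # normalized over dim 0

This keeps the call site clean, but under the hood it still pays for the non-contiguous access pattern, so I would expect timings similar to the Net2 case above.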