nn.LayerNorm for a specific dimension of my tensor?

I’d like to apply layernorm to a specific dimension of my tensor.

input = torch.randn(N, C, H, W)

In the above example, I’d like to apply layernorm along the C dimension.

Looking at the LayerNorm documentation, as I understand it, you can only tell nn.LayerNorm the shape of the dimension(s) to normalize (normalized_shape), not which dimension to normalize. It always normalizes over the trailing dimension(s) of the input. I think this creates a problem if you'd like to apply layernorm to a dimension further left, such as C.

Concretely, if I do the following, I believe it actually applies layernorm to the last dimension, W, and it only runs at all if W happens to be the same size as C.

input = torch.randn(N, C, H, W)
layernorm = nn.LayerNorm(C)
output = layernorm(input)
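A quick way to check which dimension nn.LayerNorm(C) normalizes over is to look at the per-dimension means of the output: the normalized dimension should have a mean of roughly zero. This is only a small sketch. The sizes (all of C, H, W equal to 3, plus a mismatched last dimension of 5) are my own choice for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, C, H, W = 2, 3, 3, 3  # example sizes, chosen so that C == H == W

x = torch.randn(N, C, H, W)
out = nn.LayerNorm(C)(x)

# The dimension that was normalized has (approximately) zero mean.
mean_over_W = out.mean(dim=-1).abs().max()
mean_over_H = out.mean(dim=2).abs().max()
mean_over_C = out.mean(dim=1).abs().max()
print(mean_over_W, mean_over_H, mean_over_C)

# If the last dimension does not have size C, the call fails outright.
raised = False
try:
    nn.LayerNorm(C)(torch.randn(N, C, H, 5))
except RuntimeError:
    raised = True
print("mismatched last dim raises:", raised)
```

In this sketch only the mean over the last dimension comes out near zero, which suggests that the rightmost dimension is the one being normalized, regardless of which other dimensions share its size.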

Is there a way around this?

I suppose one solution is to transpose (perhaps using permute) before calling LayerNorm, but that feels a bit inelegant.


The approach I ended up using

I ended up using permute to make C the rightmost dimension before LayerNorm, and then permuting again to go back to the original shape.

Let’s do a simpler example with 3 dimensions instead of 4:

import torch
from torch import nn

def get_input_tensor(dims):
    t = torch.zeros(dims)
    t_flat = t.view(t.numel()) # thx: https://discuss.pytorch.org/t/any-alternatives-to-flat-for-tensor/3106

    # fill with something like [[[0,1,2], [3,4,5]]]
    for i in range(t_flat.numel()):
        t_flat[i] = i
    return t

N, C, W = 1, 2, 3  # example sizes (not in the original post; they match the fill comment above)

layernorm = nn.LayerNorm(C)

input = get_input_tensor([N,C,W])
x = input.permute(0, 2, 1) # [N, C, W] --> [N, W, C]
x = layernorm(x)
output = x.permute(0, 2, 1) # [N, W, C] --> [N, C, W]

In practice, of course we’d want to put this in an nn.Module and initialize the nn.LayerNorm in the module’s __init__() function.
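A minimal sketch of what that module could look like is below. The class name ChannelLayerNorm and its argument names are made up for this example, and it assumes a 3-dimensional [N, C, W] input as in the snippet above.

```python
import torch
from torch import nn

class ChannelLayerNorm(nn.Module):
    """LayerNorm over dim 1 of a [N, C, W] tensor (hypothetical wrapper)."""

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        # Created once here so its weight and bias are registered as parameters.
        self.layernorm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 1)      # [N, C, W] --> [N, W, C]
        x = self.layernorm(x)       # normalizes over the last dimension, C
        return x.permute(0, 2, 1)   # [N, W, C] --> [N, C, W]

x = torch.randn(2, 3, 5)
out = ChannelLayerNorm(3)(x)
print(out.shape)  # same shape as the input
```

For a [N, C, H, W] input the same idea should carry over by moving dim 1 to the end and back, for example with permute(0, 2, 3, 1) and permute(0, 3, 1, 2), though I have not tested that variant here.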

I haven’t done any careful speed testing to see whether the permute adds much runtime.

Correctness check

I was able to get the above to match the numerics of a hand-coded LayerNorm that operates on the middle dimension of a [N, C, W] input tensor:

# adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py#L317
class LayerNorm_Custom(nn.Module):
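    def __init__(self, hidden_size, eps=1e-12):
        # nn.Module must be initialized before any nn.Parameter is assigned
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        # note: nn.LayerNorm defaults to eps=1e-5, so pass the same eps to both when comparing
        self.variance_epsilon = eps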

    def forward(self, x):
        x = x.permute(0, 2, 1) # [N, C, W] --> [N, W, C]
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        x = self.weight * x + self.bias
        x = x.permute(0, 2, 1)  # [N, W, C] --> [N, C, W]
        return x
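For anyone who wants to reproduce the comparison, here is a rough sketch of one way to do it. It is not the exact script I ran. The sizes, the seed, and the tolerance are arbitrary choices, and the hand-coded side here normalizes directly over dim 1 rather than going through the class above.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, C, W = 4, 5, 7  # arbitrary example sizes
eps = 1e-5         # use the same eps on both sides of the comparison

x = torch.randn(N, C, W)

# Built-in LayerNorm, applied to the C dimension via permute.
layernorm = nn.LayerNorm(C, eps=eps)
out_builtin = layernorm(x.permute(0, 2, 1)).permute(0, 2, 1)

# Hand-coded normalization over dim 1 (C), with no permute at all.
mean = x.mean(dim=1, keepdim=True)
var = (x - mean).pow(2).mean(dim=1, keepdim=True)  # biased variance
out_manual = (x - mean) / torch.sqrt(var + eps)

# A fresh nn.LayerNorm has weight = 1 and bias = 0, so the affine step is an identity here.
match = torch.allclose(out_builtin, out_manual, atol=1e-5)
print("outputs match:", match)
```

Keeping eps identical matters: the custom class above defaults to 1e-12 while nn.LayerNorm defaults to 1e-5, so with the defaults the two can differ slightly, most noticeably when the variance along C is very small.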