The approach I ended up using
I ended up using permute to make C the rightmost dimension before LayerNorm, and then permuting again to go back to the original shape.
Let’s do a simpler example with 3 dimensions instead of 4:
import torch
from torch import nn

def get_input_tensor(dims):
    t = torch.zeros(dims)
    t_flat = t.view(t.numel())  # thx: https://discuss.pytorch.org/t/any-alternatives-to-flat-for-tensor/3106
    # fill with consecutive values, e.g. [[[0,1,2], [3,4,5], [6,7,8]]] for a [1,3,3] tensor
    for i in range(t_flat.numel()):
        t_flat[i] = i
    return t

N = 1
C = 3
W = 3

layernorm = nn.LayerNorm(C)

input = get_input_tensor([N, C, W])
x = input.permute(0, 2, 1)   # [N, C, W] --> [N, W, C]
x = layernorm(x)
output = x.permute(0, 2, 1)  # [N, W, C] --> [N, C, W]
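A quick sanity check: the two permutes round-trip the shape, and the values that get normalized together are the C entries at each (n, w) position. With the default affine parameters (weight=1, bias=0), each length-C slice ends up roughly zero-mean, unit-variance:

print(input.shape)    # torch.Size([1, 3, 3])
print(output.shape)   # torch.Size([1, 3, 3]) -- same shape as the input
print(output[0, :, 0].mean().item())               # ~0.0
print(output[0, :, 0].std(unbiased=False).item())  # ~1.0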
In practice, of course, we'd want to put this in an nn.Module and initialize the nn.LayerNorm in the module's __init__() function. I haven't done any careful speed testing to see whether the permute adds much runtime.
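For reference, a minimal sketch of that kind of module might look like the following (the class name is arbitrary; it just wraps nn.LayerNorm with the two permutes):

class LayerNormOverChannels(nn.Module):
    """LayerNorm over the C dimension of an [N, C, W] tensor."""
    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        x = x.permute(0, 2, 1)     # [N, C, W] --> [N, W, C]
        x = self.layernorm(x)
        return x.permute(0, 2, 1)  # [N, W, C] --> [N, C, W]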
Correctness check
I was able to get the above to match the numerics of a hand-coded LayerNorm that operates on the middle dimension of an [N, C, W] input tensor:
# adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py#L317
class LayerNorm_Custom(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [N, C, W] --> [N, W, C]
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        x = self.weight * x + self.bias
        x = x.permute(0, 2, 1)  # [N, W, C] --> [N, C, W]
        return x
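The comparison itself can be as simple as something like this; note that eps should be passed explicitly, since nn.LayerNorm defaults to 1e-5 while the class above defaults to 1e-12:

custom = LayerNorm_Custom(C, eps=1e-5)
builtin = nn.LayerNorm(C, eps=1e-5)

x = get_input_tensor([N, C, W])
with torch.no_grad():
    out_custom = custom(x)  # handles the permutes internally
    out_builtin = builtin(x.permute(0, 2, 1)).permute(0, 2, 1)

print(torch.allclose(out_custom, out_builtin))  # True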