Hmm… Let’s do some checks.
import torch
from torch import nn
x = torch.randn(20, 6, 10, 10)
g = nn.GroupNorm(1, 6, affine=False)                      # one group covering all 6 channels
l = nn.LayerNorm((6, 10, 10), elementwise_affine=False)   # normalise over the full (C, H, W)
y_g = g(x)
y_l = l(x)
(y_g - y_l).pow(2).sum().sqrt()
# tensor(5.1145e-06)
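The two match here because, with a single group, GroupNorm standardises each sample over all of (C, H, W), which is exactly what LayerNorm((6, 10, 10)) does. A minimal sketch of that computation done by hand (assuming the default eps and the biased variance that both layers use):
import torch
from torch import nn
x = torch.randn(20, 6, 10, 10)
g = nn.GroupNorm(1, 6, affine=False)
# standardise each sample over all of (C, H, W) manually
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mean) / (var + g.eps).sqrt()
(g(x) - y_manual).pow(2).sum().sqrt()  # again of the order of 1e-06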
Now, when you type nn.LayerNorm(6), you're instructing torch to compute the normalisation over a single dimension, i.e. the last one. So, when you feed the permuted input to your LayerNorm module, it will compute a normalisation only over the 6 channels, at every location of the map. No wonder you got a difference of 55.1597 in your first snippet.
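For reference, here is a minimal sketch of that kind of permuted comparison (I'm guessing at what your first snippet looked like; the exact figure depends on the random input, but it lands in the tens rather than around 1e-06):
import torch
from torch import nn
x = torch.randn(20, 6, 10, 10)
g = nn.GroupNorm(1, 6, affine=False)
# LayerNorm(6) normalises only over the last dimension,
# i.e. over the 6 channels at each spatial location
l = nn.LayerNorm(6, elementwise_affine=False)
y_g = g(x)
# move channels last, normalise over them, then move them back
y_l = l(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
(y_g - y_l).pow(2).sum().sqrt()  # large, of the order of tens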
Finally, GroupNorm uses a (global) channel-wise learnable scale and bias, while LayerNorm has a (local) scale and bias for each location as well. Unless you share them across all locations for LayerNorm, LayerNorm will be more flexible than GroupNorm with a single group. You can see how their C++ implementations differ below.
// global scale and bias
for (const auto k : c10::irange(HxW)) {
Y_ptr[k] = scale * X_ptr[k] + bias;
}
// per location scale and bias
vec::map3<T>(
[scale, bias](Vec x, Vec gamma, Vec beta) {
return (x * Vec(scale) + Vec(bias)) * gamma + beta;
},
Y_ptr,
X_ptr,
gamma_data,
beta_data,
N
);
Where map tells you that (citing Wikipedia):
a simple operation is applied to all elements of a sequence, potentially in parallel [1]. It is used to solve embarrassingly parallel problems: those problems that can be decomposed into independent subtasks, requiring no communication/synchronization between the subtasks.
Answering your question, GroupNorm(num_groups=1) and LayerNorm are not equivalent, unless followed by a fully-connected layer.
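As a quick way to see the difference in affine flexibility mentioned above, you can compare the shapes of the learnable parameters (a small check using the same sizes as before):
from torch import nn
g = nn.GroupNorm(1, 6)         # affine=True by default
l = nn.LayerNorm((6, 10, 10))  # elementwise_affine=True by default
# GroupNorm: one scale/bias per channel, shared across all locations
print(g.weight.shape, g.bias.shape)  # torch.Size([6]) torch.Size([6])
# LayerNorm: one scale/bias per channel *and* per location
print(l.weight.shape, l.bias.shape)  # torch.Size([6, 10, 10]) torch.Size([6, 10, 10])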