I want to apply normalization along the channel dimension of data with shape NCHW. However, torch.nn.LayerNorm
only supports normalizing over the last D dimensions (I want to skip the spatial dimensions).
This requirement arises in networks like NAFNet (a popular network for image restoration). They have to implement a special LayerNorm themselves, but that is likely slower than a fast CUDA implementation the PyTorch team could provide.
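To make the request concrete, here is a minimal sketch of the workaround I mean: permute NCHW to NHWC, normalize over the last dimension, and permute back. The class name ChannelLayerNorm and the eps default are my own choices for illustration; this is not NAFNet's code and not an existing torch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelLayerNorm(nn.Module):
    """Hypothetical helper: LayerNorm over C only, for NCHW input."""

    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        # One scale and one shift per channel (assumed parameterization)
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC so that C becomes the last dimension
        x = x.permute(0, 2, 3, 1)
        # Normalize each spatial position independently across its channels
        x = F.layer_norm(x, (x.shape[-1],), self.weight, self.bias, self.eps)
        # NHWC -> NCHW
        return x.permute(0, 3, 1, 2)


x = torch.randn(2, 8, 4, 5)
norm = ChannelLayerNorm(8)
y = norm(x)

# Reference computation: per-position mean and biased variance over dim=1
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
ref = (x - mean) / torch.sqrt(var + norm.eps)
```

This gives the result I want, but the two permutes change the memory layout around the normalization call, which is the kind of overhead a native channel-dimension option might avoid.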