How to apply LayerNorm over the channel dimension of image data with shape NCHW?

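I want to normalize along the channel dimension of data with shape NCHW. However, torch.nn.LayerNorm only supports normalizing over the last D dimensions, and since C is not a trailing dimension of an NCHW tensor, I can't select it without also normalizing over the spatial dimensions (H and W), which I want to skip.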
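One common workaround, sketched here under the assumption that a permute/copy is acceptable, is to move C to the last position, apply the built-in nn.LayerNorm over that single dimension, and move it back. The class name ChannelLayerNorm is hypothetical; it only wraps the real nn.LayerNorm.

```python
import torch
import torch.nn as nn


class ChannelLayerNorm(nn.Module):
    """Hypothetical helper: LayerNorm over C for NCHW input via NHWC permute."""

    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        # normalized_shape=num_channels: normalize over the last dimension only
        self.norm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C, H, W) -> (N, H, W, C): channels become the last dimension
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        # (N, H, W, C) -> (N, C, H, W); the result is a non-contiguous view
        return x.permute(0, 3, 1, 2)


torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
y = ChannelLayerNorm(8)(x)
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Each (n, h, w) position is normalized independently across its 8 channels, so H and W are never mixed into the statistics. The two permutes keep the built-in kernel, but they may add memory traffic, so this is not necessarily the fastest option on every device.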

This requirement comes up in networks such as NAFNet (a popular architecture for image restoration), which has to implement its own channel-wise LayerNorm. A hand-written implementation like that is likely slower than a fused CUDA kernel provided by the PyTorch team would be.
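For comparison, a hand-written channel-wise LayerNorm can be sketched with plain tensor ops: take the mean and biased variance over dim=1 with keepdim=True so they broadcast back over NCHW, then apply a per-channel scale and shift. This is only an illustration of the kind of module such networks define; it is not NAFNet's actual code, and the name LayerNorm2d here is an assumption.

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.Module):
    """Illustrative channel-wise LayerNorm for NCHW tensors (not NAFNet's code)."""

    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # One scale/shift per channel, initialised like nn.LayerNorm (ones, zeros)
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics over C only; keepdim lets them broadcast over (N, C, H, W)
        mu = x.mean(dim=1, keepdim=True)
        var = (x - mu).pow(2).mean(dim=1, keepdim=True)  # biased, as in LayerNorm
        x_hat = (x - mu) / torch.sqrt(var + self.eps)
        # Reshape the affine parameters to (1, C, 1, 1) for broadcasting
        return self.weight.view(1, -1, 1, 1) * x_hat + self.bias.view(1, -1, 1, 1)


torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
manual = LayerNorm2d(8)(x)
builtin = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```

Both versions compute the same result; the manual one launches several separate elementwise and reduction kernels, which is why it may be slower than a single fused kernel. Whether it actually is slower than the permute route depends on the device and tensor sizes, so it is worth benchmarking both.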