Group Normalization Implementation and Instance Normalization 4d

Recently, the Group Normalization paper has become very popular. I'm trying to implement it in PyTorch, and my idea is to reshape the input and use PyTorch's Instance Norm. For 2d images, I can easily implement it with this code:

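import torch
import torch.nn as nn

N, C, H, W = 8, 32, 16, 16  # example sizes; C must be divisible by G
G = 4  # number of groups (not the group size)
input = torch.randn(N, C, H, W)
gn_func = nn.InstanceNorm3d(G)  # affine=False by default, so no learned scale/shift
# implement GN: split the C channels into G groups of C // G channels each
input = input.view(N, G, C // G, H, W)
input = gn_func(input)  # normalizes each (n, g) slice over (C // G, H, W)
input = input.view(N, C, H, W)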
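One way to sanity-check the reshape trick is to compare it against `torch.nn.functional.group_norm`, which later PyTorch releases provide. This is a minimal sketch assuming such a release; `gn_via_instance_norm` is a hypothetical helper name, and both sides run without learned scale/shift and with the default eps of 1e-5.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gn_via_instance_norm(x: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: group norm on (N, C, H, W) via the InstanceNorm3d reshape trick."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "C must be divisible by the number of groups"
    inorm = nn.InstanceNorm3d(num_groups)  # affine=False, eps=1e-5 by default
    # treat each group as one "channel" holding a (C // G, H, W) volume
    y = inorm(x.reshape(n, num_groups, c // num_groups, h, w))
    return y.reshape(n, c, h, w)


torch.manual_seed(0)
x = torch.randn(2, 8, 5, 5)
ours = gn_via_instance_norm(x, num_groups=4)
ref = F.group_norm(x, num_groups=4)  # no weight/bias: normalization only
print(torch.allclose(ours, ref, atol=1e-5))
```

Note that GN's learned scale and shift are per channel (C of them), so they cannot come from `InstanceNorm3d(G, affine=True)`, which would learn G of each; they have to be applied separately after reshaping back.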

And this works well. But when I need to process 3d data, I find that there is no InstanceNorm4d; all the norm layers only support 1d, 2d, and 3d data. So my question is: how can I do batch norm or instance norm on 4d data like (N, C, T, D, H, W)? In TensorFlow I can do batch norm along any axis, but I don’t know how to do it in PyTorch.
Alternatively, any other suggestions on how to implement GN for 3d data would be great.
Thanks a lot!
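Instance norm and batch norm only depend on which axes are reduced, so one workaround for (N, C, T, D, H, W) data is to flatten every dimension after C into a single axis, normalize with the 1d versions, and reshape back. A minimal sketch, assuming a PyTorch version with `torch.nn.functional.instance_norm` and an `nn.BatchNorm1d` that accepts (N, C, L) input; the helper names are hypothetical. TensorFlow-style "batch norm along any axis" is emulated by transposing the chosen axis into position 1 first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def instance_norm_nd(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: instance norm for (N, C, *spatial) with any number of spatial dims."""
    n, c = x.shape[:2]
    # stats per (n, c) over all spatial positions, same as InstanceNorm{1,2,3}d
    y = F.instance_norm(x.reshape(n, c, -1), eps=eps)
    return y.reshape(x.shape)


def batch_norm_nd(x: torch.Tensor, bn: nn.BatchNorm1d) -> torch.Tensor:
    """Hypothetical helper: apply BatchNorm1d(C) to (N, C, *spatial) input."""
    n, c = x.shape[:2]
    # BatchNorm1d on (N, C, L) reduces over N and L, i.e. over N and all spatial dims
    return bn(x.reshape(n, c, -1)).reshape(x.shape)


def batch_norm_along_axis(x: torch.Tensor, bn: nn.BatchNorm1d, axis: int) -> torch.Tensor:
    """Hypothetical helper: batch norm with one set of statistics per index along `axis`."""
    x_t = x.transpose(1, axis)  # bring the feature axis to dim 1
    return batch_norm_nd(x_t, bn).transpose(1, axis)


x = torch.randn(2, 3, 4, 5, 6, 7)  # (N, C, T, D, H, W)
inst = instance_norm_nd(x)
bn_out = batch_norm_nd(x, nn.BatchNorm1d(3))
# statistics per index of the D axis (size 5) instead of per channel
d_out = batch_norm_along_axis(x, nn.BatchNorm1d(5), axis=3)
```

Because the flattening only merges axes that are reduced together anyway, the statistics are the same as a true N-d norm layer would compute; running stats in `BatchNorm1d` still work as usual.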
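For GN on 3d data specifically, another option is to skip InstanceNorm and compute the group statistics directly, which works for any number of spatial dims. This is a sketch under the assumption that GN normalizes each group over its C // G channels and all spatial positions with a biased variance, then applies a per-channel scale and shift; `GroupNormND` is a hypothetical name. Later PyTorch releases ship `nn.GroupNorm`, which accepts (N, C, *) input and should be preferred where available.

```python
import torch
import torch.nn as nn


class GroupNormND(nn.Module):
    """Hypothetical GN layer for (N, C, *spatial) input, e.g. (N, C, D, H, W)."""

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super().__init__()
        assert num_channels % num_groups == 0, "C must be divisible by G"
        self.num_groups = num_groups
        self.eps = eps
        # learned per-channel scale and shift (not per-group)
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c = x.shape[:2]
        # each row holds one group: its C // G channels times all spatial positions
        g = x.reshape(n, self.num_groups, -1)
        mean = g.mean(dim=-1, keepdim=True)
        var = ((g - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
        g = (g - mean) / torch.sqrt(var + self.eps)
        x = g.reshape(x.shape)
        # broadcast the (C,) parameters over N and every spatial dim
        shape = (1, c) + (1,) * (x.dim() - 2)
        return x * self.weight.view(shape) + self.bias.view(shape)


gn = GroupNormND(num_groups=4, num_channels=32)
video = torch.randn(2, 32, 8, 16, 16)  # (N, C, D, H, W)
out = gn(video)
print(out.shape)  # torch.Size([2, 32, 8, 16, 16])
```

The reshape to (N, G, -1) relies on the channels of one group being adjacent along C, which matches the 2d reshape trick in the question.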
