You can do this for any tensor or batch of tensors in three steps:
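```python
import torch

a = torch.randn(2, 3, 4, 4)  # batch=2, channels=3, H=4, W=4; normalize each channel of each sample
mu = torch.mean(a, dim=(2, 3), keepdim=True)  # shape (2, 3, 1, 1): mean over H and W
sd = torch.std(a, dim=(2, 3), keepdim=True)   # shape (2, 3, 1, 1): std over H and W
normalized_res = (a - mu) / sd                # broadcasting applies each (sample, channel) statistic
```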
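Two details may matter depending on what you want. First, `dim=(2, 3)` computes statistics for each channel of each sample separately (instance normalization); to pool each channel's statistics across the whole batch, as batch norm does, also reduce over dim 0. Second, `torch.std` applies Bessel's correction by default (divides by N-1), while PyTorch's built-in normalization uses the biased variance plus a small `eps`. The sketch below assumes `eps=1e-5` (the default of `F.instance_norm` and `F.batch_norm`) and cross-checks both variants against those functions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(2, 3, 4, 4)  # batch=2, channels=3, H=4, W=4
eps = 1e-5  # assumption: chosen to match the built-ins' default eps

# Per sample and per channel (instance norm): reduce over H and W only.
mu = a.mean(dim=(2, 3), keepdim=True)
var = ((a - mu) ** 2).mean(dim=(2, 3), keepdim=True)  # biased variance (divide by N)
inst = (a - mu) / torch.sqrt(var + eps)

# Per channel across the whole batch (batch-norm style): also reduce over dim 0.
mu_c = a.mean(dim=(0, 2, 3), keepdim=True)
var_c = ((a - mu_c) ** 2).mean(dim=(0, 2, 3), keepdim=True)
per_channel = (a - mu_c) / torch.sqrt(var_c + eps)

# Cross-check against PyTorch's functional implementations.
print(torch.allclose(inst, F.instance_norm(a, eps=eps), atol=1e-5))
print(torch.allclose(per_channel, F.batch_norm(a, None, None, training=True, eps=eps), atol=1e-5))
```

Both comparisons are expected to hold up to float32 rounding. The `eps` term also guards against division by zero when a channel is constant, which the plain `(a - mu) / sd` version does not.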