Is this the correct way to do cross-channel normalization?

You could probably wrap it in a

with torch.no_grad():
    ...

block instead, as described here.
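
For example, here is a minimal sketch of normalizing a feature map across its channel dimension inside a no_grad block; the tensor shape and eps value are just assumptions for illustration:

import torch

x = torch.randn(8, 16, 32, 32, requires_grad=True)  # hypothetical (N, C, H, W) input

with torch.no_grad():
    # Normalize each spatial position across the channel dimension (dim=1).
    # Nothing inside this block is recorded in the autograd graph,
    # so these ops won't participate in backprop.
    mean = x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    x_norm = (x - mean) / (std + 1e-5)  # eps avoids division by zero

Note that because no graph is built, gradients won't flow through the normalization; if you need it to be differentiable, you'd leave no_grad out instead.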