Conditional Normalization of PyTorch Tensor

Suppose I have a tensor, s, that I wish to normalize using the mean and standard deviation computed along its last axis:

s.shape

> torch.Size([B,C,X,Y])

I can achieve this using:

s_mean = torch.mean(s, axis=-1).unsqueeze(-1)
s_std = torch.std(s, axis=-1).unsqueeze(-1)
s_norm = (s-s_mean)/(s_std)

However, some slices along the last axis (Y) have mean 0 and variance 0. In other words, every value in such a slice is the same (namely 0), so the division above evaluates 0/0 and produces NaN.
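A minimal sketch reproducing the problem, assuming a tiny illustrative tensor of shape [1, 1, 2, 4] whose second slice along the last axis is all zeros (the values are made up for the demonstration). It uses the same mean/std/unsqueeze approach as above, with dim= instead of axis=:

```python
import torch

# Illustrative tensor of shape [B, C, X, Y] = [1, 1, 2, 4];
# the second slice along the last axis is constant (all zeros).
s = torch.tensor([[[[1.0, 2.0, 3.0, 4.0],
                    [0.0, 0.0, 0.0, 0.0]]]])

s_mean = torch.mean(s, dim=-1).unsqueeze(-1)
s_std = torch.std(s, dim=-1).unsqueeze(-1)
s_norm = (s - s_mean) / s_std

# The constant slice has std 0, so (0 - 0) / 0 yields NaN everywhere in it,
# while the non-constant slice is normalized as expected.
print(torch.isnan(s_norm[0, 0, 1]).all())
print(torch.isnan(s_norm[0, 0, 0]).any())
```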

I would like to normalize only the slices whose mean and standard deviation are non-zero, and set the values in the remaining slices to zero.

Which functions do I need to achieve this?

I hope you found a solution in the two years since you asked, but I'm posting an answer here in case someone else runs into the same issue.

You can use torch.where to pick, element by element, between two values based on a condition (i.e. a bool tensor). You can also pass keepdim=True to the reductions so the reduced dimension is kept with size 1, which makes the unsqueeze(-1) calls unnecessary.
Concretely, something like this should solve your issue:

s_mean = torch.mean(s, dim=-1, keepdim=True)
s_std = torch.std(s, dim=-1, keepdim=True)
s_norm = torch.where(s_std != 0, (s - s_mean) / s_std, 0)
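For anyone who wants to try this end to end, here is a minimal self-contained sketch of the same idea wrapped in a hypothetical helper, masked_normalize (not a PyTorch function). It makes two assumption-driven changes to the snippet above: it passes torch.zeros_like(...) instead of the scalar 0, since older PyTorch releases may not accept a Python scalar as the third argument of torch.where, and it replaces zero standard deviations with 1 before dividing, so 0/0 is never computed. The second change is a common precaution when gradients are needed, because a NaN in the branch that torch.where discards can still show up as NaN in the backward pass.

```python
import torch


def masked_normalize(s: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Standardize `s` along `dim`; slices with zero std become all zeros.

    Hypothetical helper for illustration, not part of PyTorch.
    """
    # keepdim=True keeps a size-1 axis, so the stats broadcast against `s`.
    s_mean = torch.mean(s, dim=dim, keepdim=True)
    s_std = torch.std(s, dim=dim, keepdim=True)
    valid = s_std != 0
    # Use 1 as the denominator for constant slices so 0/0 is never evaluated.
    safe_std = torch.where(valid, s_std, torch.ones_like(s_std))
    normed = (s - s_mean) / safe_std
    # The size-1 mask broadcasts over `dim`; constant slices are set to 0.
    return torch.where(valid, normed, torch.zeros_like(normed))


s = torch.tensor([[[[1.0, 2.0, 3.0, 4.0],
                    [0.0, 0.0, 0.0, 0.0]]]])
out = masked_normalize(s)
print(out)
```

Here the first slice is standardized as usual (torch.std applies Bessel's correction by default, as in the question), and the all-zero slice stays zero instead of turning into NaN.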