Suppose I have a tensor, `a`, which I wish to normalize by the mean and standard deviation along its last axis:

```
a.shape
> torch.Size([B,C,X,Y])
```

I can achieve this using:

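```
import torch

# Mean and (unbiased) std over the last axis (Y); keepdim=True keeps
# shape (B, C, X, 1) so both broadcast against `a` of shape (B, C, X, Y).
a_mean = torch.mean(a, dim=-1, keepdim=True)
a_std = torch.std(a, dim=-1, keepdim=True)
a_norm = (a - a_mean) / a_std
```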

However, several slices along the axis `Y` have mean 0 and variance 0, i.e. all of their values are the same (and therefore all zero). For these slices the normalization computes (0 - 0) / 0, which gives `nan`.
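A small reproduction of the problem, using a hypothetical toy tensor of shape (1, 1, 2, 3) in which the second slice along `Y` is all zeros:

```python
import torch

# Hypothetical toy input with shape (B, C, X, Y) = (1, 1, 2, 3);
# the second slice along Y is constant (all zeros).
a = torch.tensor([[[[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]]]])

a_mean = torch.mean(a, dim=-1, keepdim=True)
a_std = torch.std(a, dim=-1, keepdim=True)
a_norm = (a - a_mean) / a_std

# The first slice normalizes to [-1, 0, 1]; the constant slice becomes
# (0 - 0) / 0, i.e. nan in every position.
print(a_norm)
```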

I would like to normalize only the slices with a non-zero mean and standard deviation, and set the remaining values to zero.

Which functions do I need to achieve this?
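One possible approach, sketched here under some assumptions, uses a boolean mask together with `torch.where`. It assumes that "remaining values" means slices whose standard deviation is zero (a zero std already implies the slice is constant). If the mean must also be non-zero, the mask could be extended with `& (a_mean != 0)`. The function name `normalize_last_axis` is hypothetical.

```python
import torch


def normalize_last_axis(a: torch.Tensor) -> torch.Tensor:
    """Normalize along the last axis; constant slices (std == 0) become 0.

    Hypothetical helper: masks on std alone, assuming that is the intended
    condition for "remaining values".
    """
    a_mean = a.mean(dim=-1, keepdim=True)
    a_std = a.std(dim=-1, keepdim=True)
    # Boolean mask of shape (..., 1): True where the slice is not constant.
    valid = a_std != 0
    # Replace zero std by 1 so the division never produces nan or inf.
    safe_std = torch.where(valid, a_std, torch.ones_like(a_std))
    a_norm = (a - a_mean) / safe_std
    # Zero out every value of a constant slice (the mask broadcasts over Y).
    return torch.where(valid, a_norm, torch.zeros_like(a_norm))


# Usage: random data with one slice along Y forced to be constant.
a = torch.randn(2, 3, 4, 5)
a[0, 0, 0] = 0.0
out = normalize_last_axis(a)
print(torch.isnan(out).any())  # no nan anywhere
```

Replacing the zero std before dividing (rather than dividing first and cleaning up `nan` afterwards) avoids producing `nan` at all, which also keeps gradients finite if this runs inside a model.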