Using tensor.mean() on columns while ignoring certain values

This revisits an old question: what about taking the mean of each column of a 2D array? torch.mean can take a dim parameter to return the mean of each column. Can we do the same with a mask that filters out certain bad values? We can loop through the columns as follows, but is there a better way?

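    # loop over columns; statistics come from the training split only
    for i in range(y['train'].shape[1]):
        mask = y['train'][:, i] != bad_value      # valid entries in column i
        masked_y = y['train'][:, i][mask]
        y_mean = torch.mean(masked_y)
        y_std = torch.std(masked_y)
        # normalize the valid entries of each split in place; bad values are left as-is
        y['train'][:, i][mask] = (y['train'][:, i][mask] - y_mean) / y_std
        mask = y['val'][:, i] != bad_value
        y['val'][:, i][mask] = (y['val'][:, i][mask] - y_mean) / y_std
        mask = y['test'][:, i] != bad_value
        y['test'][:, i][mask] = (y['test'][:, i][mask] - y_mean) / y_std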
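For comparison, the same per-column statistics can be computed without a Python loop using ordinary tensors: count the valid entries per column, zero out the bad ones before summing, and divide. This is a minimal sketch, not the thread's answer; it assumes each split is a 2D float tensor and that bad_value is an exact sentinel. The helper names masked_column_stats and normalize_ and the toy data are hypothetical.

```python
import torch

def masked_column_stats(x: torch.Tensor, bad_value: float):
    """Per-column mean and unbiased std of a 2D tensor, ignoring bad_value (hypothetical helper)."""
    valid = x != bad_value                              # True where the entry is usable
    count = valid.sum(dim=0)                            # number of valid entries per column
    # zero out bad entries so they do not contribute to the sums
    mean = torch.where(valid, x, torch.zeros_like(x)).sum(dim=0) / count
    # squared deviations from the column mean, again only over valid entries
    sq_dev = torch.where(valid, (x - mean) ** 2, torch.zeros_like(x))
    # Bessel's correction (count - 1), matching torch.std's default
    std = (sq_dev.sum(dim=0) / (count - 1)).sqrt()
    return mean, std

def normalize_(x: torch.Tensor, mean, std, bad_value: float):
    """Normalize x column-wise in place, leaving bad_value entries untouched (hypothetical helper)."""
    valid = x != bad_value
    x[valid] = ((x - mean) / std)[valid]                # mean/std broadcast over the rows
    return x

# toy data; assumption: 3.0 marks a bad value
bad_value = 3.0
y = {
    "train": torch.arange(15.0).reshape(5, 3),
    "val": torch.tensor([[3.0, 2.0, 1.0]]),
    "test": torch.tensor([[6.75, 7.0, 8.0]]),
}
mean, std = masked_column_stats(y["train"], bad_value)   # statistics from train only
for split in ("train", "val", "test"):
    normalize_(y[split], mean, std, bad_value)
```

As in the loop above, the statistics come only from the training split and are reused for val and test. A column with fewer than two valid entries would give a nan or inf std here.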

Hi Jerron!

This looks like one of the use cases for which pytorch's prototype MaskedTensor
was developed.

Here’s a simple example:

>>> import torch
>>> torch.__version__
'2.0.1'
>>> t = torch.arange (3.)
>>> t
tensor([0., 1., 2.])
>>> m = t != 2.
>>> m
tensor([ True,  True, False])
>>> mt = torch.masked.masked_tensor (t, m)
<path_to_pytorch_install>\torch\masked\maskedtensor\core.py:156: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
>>> mt
MaskedTensor(
  [  0.0000,   1.0000,       --]
)
>>> t.mean()
tensor(1.)
>>> mt.mean()
MaskedTensor(  0.5000, True)
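The example above reduces a 1D tensor. Since the question is about columns, here is a minimal sketch of the same idea on a 2D tensor with dim=0. It assumes the prototype torch.masked API as used in this thread (PyTorch 2.0.1), which may change; the toy data and the choice of 3.0 as the bad value are illustrative only.

```python
import torch

bad_value = 3.0                                   # assumption: 3.0 marks a bad entry
y = torch.arange(15.0).reshape(5, 3)              # toy data; row 1, column 0 holds the bad value

my = torch.masked.masked_tensor(y, y != bad_value)   # prototype API; emits a UserWarning
col_mean = torch.mean(my, dim=0)   # per-column mean over unmasked entries only
col_std = torch.std(my, dim=0)     # per-column (unbiased) std, same masking

# to_tensor(fill) returns a regular Tensor with any masked-out slot set to fill
print(col_mean.to_tensor(float("nan")))   # expected values: 6.75, 7.0, 8.0
```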

Does this do what you want?

Best.

K. Frank


Thanks! This is exactly what I'm looking for.

Hi Frank, may I follow up with a question about operations on the MaskedTensor?
I computed the mean and standard deviation of each column of the masked matrix:

  mask = y != bad_value
  my = torch.masked.masked_tensor(y, mask)
  y_mean, y_std = torch.mean(my, dim=0), torch.std(my, dim=0)
  y_mean, y_std
  (MaskedTensor(
    [  0.0643,   0.2096,  -0.0998]
  ),
  MaskedTensor(
    [  0.4513,   0.4316,   0.4149]
  ))

How can I use them, for example, to get normalized values? A regular tensor can be shifted and scaled by them, but an operation whose operands are both MaskedTensors fails.
This succeeds:

my.shape, y.shape, (y-y_mean).shape,(y/y_std).shape
(torch.Size([9933, 3]), torch.Size([9933, 3]), torch.Size([9933, 3]), torch.Size([9933, 3]))

and this failed:

(y-y_mean)/y_std
ValueError                                Traceback (most recent call last)

<ipython-input-17-7d6c5321b00a> in <cell line: 1>()
----> 1 (y-y_mean)/y_std


/usr/local/lib/python3.10/dist-packages/torch/masked/maskedtensor/binary.py in _binary_helper(fn, args, kwargs, inplace)
     83 
     84     if not _masks_match(*args[:2]):
---> 85         raise ValueError(
     86             "Input masks must match. If you need support for this, please open an issue on Github."
     87         )

ValueError: Input masks must match. If you need support for this, please open an issue on Github.

This also failed:

my-y_mean
ValueError                                Traceback (most recent call last)

<ipython-input-20-b1b0b5b17e0f> in <cell line: 1>()
----> 1 my-y_mean


/usr/local/lib/python3.10/dist-packages/torch/masked/maskedtensor/binary.py in _binary_helper(fn, args, kwargs, inplace)
     83 
     84     if not _masks_match(*args[:2]):
---> 85         raise ValueError(
     86             "Input masks must match. If you need support for this, please open an issue on Github."
     87         )

ValueError: Input masks must match. If you need support for this, please open an issue on Github.

Thanks again!

Hi Jerron!

The short story is that it appears that MaskedTensors don’t (currently)
broadcast against other MaskedTensors.

The somewhat longer story is that (for legitimate semantic reasons) pytorch
doesn’t like to perform, say, element-wise addition on MaskedTensors whose
masks differ. This probably bleeds over into the broadcasting case because
you would have to sort through what you want the semantics of broadcasting
the masks to be.
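The mask-matching rule can be seen in isolation with three small MaskedTensors. This is a minimal sketch, assuming the prototype torch.masked API of PyTorch 2.0.1 as used in this thread (later releases may relax the rule); the tensors a, b and c are hypothetical.

```python
import torch

data = torch.tensor([1.0, 2.0, 3.0])
a = torch.masked.masked_tensor(data, torch.tensor([True, True, False]))
b = torch.masked.masked_tensor(data * 10, torch.tensor([True, True, False]))   # same mask as a
c = torch.masked.masked_tensor(data * 10, torch.tensor([True, False, True]))   # different mask

same_masks = a + b          # identical masks: element-wise addition is allowed

try:
    a + c                   # masks differ: raises "Input masks must match" in 2.0.x
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

A shape-(3,) MaskedTensor such as y_mean paired with a shape-(N, 3) MaskedTensor such as my presumably hits the same check, which is consistent with the tracebacks above.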

One straightforward approach to your use case: Use MaskedTensors to
compute the .mean() and .std() and then convert them to regular Tensors
to perform the normalization with broadcasting.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.1
>>>
>>> bad_value = 3.0
>>> y = torch.arange (15.).reshape (5, 3)
>>>
>>> mask = y != bad_value
>>> my = torch.masked.masked_tensor (y, mask)
<path_to_pytorch_install>\torch\masked\maskedtensor\core.py:156: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.
  warnings.warn(("The PyTorch API of MaskedTensors is in prototype stage "
>>> y_mean, y_std = torch.mean (my, dim = 0), torch.std (my, dim = 0)
>>>
>>> my
MaskedTensor(
  [
    [  0.0000,   1.0000,   2.0000],
    [      --,   4.0000,   5.0000],
    [  6.0000,   7.0000,   8.0000],
    [  9.0000,  10.0000,  11.0000],
    [ 12.0000,  13.0000,  14.0000]
  ]
)
>>> y_mean, y_std
(MaskedTensor(
  [  6.7500,   7.0000,   8.0000]
), MaskedTensor(
  [  5.1235,   4.7434,   4.7434]
))
>>>
>>> my.shape, y.shape, (y - y_mean).shape, (y / y_std).shape
(torch.Size([5, 3]), torch.Size([5, 3]), torch.Size([5, 3]), torch.Size([5, 3]))
>>>
>>> y - y_mean    # MaskedTensor y_mean broadcasts against regular Tensor y
MaskedTensor(
  [
    [ -6.7500,  -6.0000,  -6.0000],
    [ -3.7500,  -3.0000,  -3.0000],
    [ -0.7500,   0.0000,   0.0000],
    [  2.2500,   3.0000,   3.0000],
    [  5.2500,   6.0000,   6.0000]
  ]
)
>>>
>>> # my - y_mean   # fails with ValueError -- y_mean doesn't broadcast against MaskedTensor my
>>>
>>> # one way to do this -- convert to regular Tensors for broadcasting ...
>>>
>>> nan = float ('nan')   # use nans to track masked values
>>>
>>> y_norm = (my.to_tensor (nan) - y_mean.to_tensor (nan)) / y_std.to_tensor (nan)   # regular Tensors will broadcast
>>> my_norm = torch.masked.masked_tensor (y_norm, ~y_norm.isnan())   # convert back to MaskedTensor, if desired
>>>
>>> y_norm
tensor([[-1.3175, -1.2649, -1.2649],
        [    nan, -0.6325, -0.6325],
        [-0.1464,  0.0000,  0.0000],
        [ 0.4392,  0.6325,  0.6325],
        [ 1.0247,  1.2649,  1.2649]])
>>> my_norm
MaskedTensor(
  [
    [ -1.3175,  -1.2649,  -1.2649],
    [      --,  -0.6325,  -0.6325],
    [ -0.1464,   0.0000,   0.0000],
    [  0.4392,   0.6325,   0.6325],
    [  1.0247,   1.2649,   1.2649]
  ]
)
>>>
>>> # check normalization
>>> y_norm.mean (dim = 0)
tensor([nan, 0., 0.])
>>> y_norm.std (dim = 0)
tensor([nan, 1., 1.])
>>>
>>> my_norm.mean (dim = 0)
MaskedTensor(
  [ -0.0000,   0.0000,   0.0000]
)
>>> my_norm.std (dim = 0)
MaskedTensor(
  [  1.0000,   1.0000,   1.0000]
)
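The same recipe (MaskedTensor reductions for the statistics, regular Tensors for the broadcasting arithmetic) can be wrapped up for the original train/val/test use case. This is a minimal sketch, assuming the prototype torch.masked API of PyTorch 2.0.1 and a dict of 2D float tensors like the one in the question; the function name standardize_splits and the toy data are hypothetical. It returns new tensors rather than modifying the inputs in place.

```python
import torch

def standardize_splits(y: dict, bad_value: float) -> dict:
    """Normalize every split column-wise with train statistics, skipping bad_value (hypothetical helper)."""
    train = y["train"]
    mtrain = torch.masked.masked_tensor(train, train != bad_value)   # prototype API
    nan = float("nan")
    # plain 1D Tensors of shape (n_columns,); a column with no valid entries becomes nan
    mean = torch.mean(mtrain, dim=0).to_tensor(nan)
    std = torch.std(mtrain, dim=0).to_tensor(nan)

    out = {}
    for split, x in y.items():
        valid = x != bad_value
        # regular Tensors broadcast normally; bad entries keep their original value
        out[split] = torch.where(valid, (x - mean) / std, x)
    return out

# toy data; assumption: 3.0 marks a bad value
y = {
    "train": torch.arange(15.0).reshape(5, 3),
    "val": torch.tensor([[3.0, 7.0, 8.0]]),
}
normed = standardize_splits(y, bad_value=3.0)
```

Using torch.where here leaves the bad entries as the sentinel value instead of nan, which mirrors what the loop in the question does.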

Best.

K. Frank