Idiomatic way to compute the number of elements of a tensor slice?

If you want to know the number of elements in an entire tensor, torch.Tensor.nelement() is convenient. Finding the number of elements spanned by a slice of the dimensions is less elegant in Torch than in NumPy, where you can apply numpy.prod to a slice of the shape. What's the idiomatic way to do this in Torch? If the current idiom is to populate a tensor with a slice of the torch.Size, wouldn't it be more natural for torch.Tensor.size() to return a tensor?

In [1]: import numpy as np; import torch

In [2]: xnp = np.random.normal(size=(3, 4, 5, 6))

In [3]: np3dims = np.prod(xnp.shape[:3])

In [4]: xt = torch.randn((3, 4, 5, 6))

In [5]: torch.prod(xt.size()[:3])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-a2fe80b56024> in <module>()
----> 1 torch.prod(xt.size()[:3])

TypeError: torch.prod received an invalid combination of arguments - got (torch.Size), but expected one of:
 * (torch.FloatTensor source)
      didn't match because some of the arguments have invalid types: (torch.Size)
 * (torch.FloatTensor source, int dim)
 * (torch.FloatTensor source, int dim, bool keepdim)

In [6]: np.prod(xt.size()[:3])
Out[6]: 60

The least ugly version I could come up with is:

from functools import reduce  # reduce is not a builtin in Python 3
import operator
reduce(operator.mul, list(xt.size())[:3])
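On Python 3.8 or newer, the standard library's math.prod performs the same reduction without importing operator or functools. Below is a minimal sketch assuming Python >= 3.8; the helper name slice_numel is hypothetical. Because torch.Size is a subclass of tuple, the helper should accept it as well as plain tuples and lists. The demo uses a plain tuple so it runs without PyTorch.

```python
import math
from typing import Optional, Sequence


def slice_numel(shape: Sequence[int], start: int = 0, stop: Optional[int] = None) -> int:
    """Number of elements spanned by the dims shape[start:stop] (hypothetical helper).

    Accepts tuples, lists, and torch.Size (a tuple subclass).
    math.prod returns 1 for an empty slice, as a plain Python int.
    """
    return math.prod(shape[start:stop])


# A plain tuple stands in for xt.size() == torch.Size([3, 4, 5, 6]).
shape = (3, 4, 5, 6)
print(slice_numel(shape, 0, 3))  # 3 * 4 * 5 = 60
print(slice_numel(shape, 3))     # 6
print(slice_numel(shape, 4))     # empty slice -> 1
```

With PyTorch installed, slice_numel(xt.size(), 0, 3) should give 60 for the tensor in the question. One reason to prefer math.prod over np.prod here is that it keeps the result a Python int. np.prod returns a NumPy scalar, and for an empty slice it returns 1.0, a float.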

It is painful that size does not return the number of elements as it does in NumPy, where shape returns the shape and size returns the element count.
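For reference, the naming differs between the two libraries. In NumPy, ndarray.size is the element count and ndarray.shape is the shape. In PyTorch, Tensor.size() returns the same torch.Size as Tensor.shape, and the element count comes from Tensor.numel() (or its alias nelement()). Here is a small sketch contrasting them, assuming both numpy and torch are installed:

```python
import numpy as np
import torch

a = np.zeros((3, 4, 5, 6))
t = torch.zeros((3, 4, 5, 6))

# NumPy: .size is the total element count, .shape is the shape tuple.
print(a.size, a.shape)          # 360 (3, 4, 5, 6)

# PyTorch: .size() and .shape are the same torch.Size; counts come from numel().
print(t.size() == t.shape)      # True
print(t.numel(), t.nelement())  # 360 360
```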

You can just use np.prod. It works with anything that is “array-like”, including tuples, lists, and torch.Size.

import numpy as np; import torch

t = torch.zeros((3, 4))
print(np.prod(t.shape))
# prints "12"
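If you'd rather stay inside PyTorch, there is also a torch-only route. Slicing a torch.Size returns another torch.Size, and torch.Size has a numel() method that returns the number of elements a tensor of that size would hold. The sketch below assumes a PyTorch version that provides torch.Size.numel() (check the docs for yours). It cross-checks the result by flattening the leading dims, and also shows the "populate a tensor from the slice" idiom mentioned in the question.

```python
import torch

xt = torch.randn((3, 4, 5, 6))

# Slicing a torch.Size gives a torch.Size; numel() multiplies its entries.
lead = xt.size()[:3]
n_lead = lead.numel()
print(n_lead)  # 60

# Cross-check: merging dims 0..2 into one gives a leading dim of the same length.
flat = xt.flatten(0, 2)  # shape (60, 6)
assert flat.shape[0] == n_lead

# Alternative: turn the slice into a tensor and reduce it (a 0-d tensor; .item() unwraps it).
print(torch.tensor(lead).prod().item())  # 60
```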