If you want to know the number of elements in an entire tensor, `torch.Tensor.nelement()` is convenient. Computing the number of elements spanned by a subset of the dimensions (e.g. the first three) is less elegant in Torch than in NumPy, where you can apply `numpy.prod` to a slice of the shape. What's the idiomatic way to do this in Torch? If the current idiom is to populate a tensor with a slice of the `torch.Size`, would it be more natural for `torch.Tensor.size()` to return a tensor type?
```
In [1]: import numpy as np; import torch

In [2]: xnp = np.random.normal(size=(3, 4, 5, 6))

In [3]: np3dims = np.prod(xnp.shape[:3])

In [4]: xt = torch.randn((3, 4, 5, 6))

In [5]: torch.prod(xt.size()[:3])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-5-a2fe80b56024> in <module>()
----> 1 torch.prod(xt.size()[:3])

TypeError: torch.prod received an invalid combination of arguments - got (torch.Size), but expected one of:
 * (torch.FloatTensor source)
      didn't match because some of the arguments have invalid types: (torch.Size)
 * (torch.FloatTensor source, int dim)
 * (torch.FloatTensor source, int dim, bool keepdim)

In [6]: np.prod(xt.size()[:3])
Out[6]: 60
```
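`np.prod` accepts the slice because `torch.Size` is a subclass of Python's `tuple`, and slicing it yields another `torch.Size`. The sketch below makes that concrete and shows tensor-free alternatives. It assumes Python 3.8+ for `math.prod` and a recent PyTorch release for `torch.Size.numel()`; the older release that produced the traceback above may lack the latter:

```python
import math
import operator
from functools import reduce

import torch

xt = torch.randn(3, 4, 5, 6)
leading = xt.size()[:3]  # slicing a torch.Size gives another torch.Size

# torch.Size subclasses tuple, which is why np.prod accepts it.
assert isinstance(leading, tuple)

# Option 1: pure Python (math.prod requires Python 3.8+).
n1 = math.prod(leading)

# Option 2: fold with multiplication; also works on older Pythons.
n2 = reduce(operator.mul, leading, 1)

# Option 3: torch.Size.numel() (assumes a recent PyTorch release).
n3 = leading.numel()

print(n1, n2, n3)  # 60 60 60
```

All three stay in plain Python integers, so no tensor is created just to multiply a handful of dimension sizes.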