softmax_out is a (5, 932, 1682) torch tensor.
m_hat = torch.stack([(r+1)*output for r, output in enumerate(softmax_out)], 0)
After this, m_hat is still a (5, 932, 1682) tensor. What does enumerate return? Why does the result keep its shape?
This is a Python 3 question rather than a PyTorch-specific one, but to answer it: enumerate yields (index, value) pairs as it walks over an iterable.
>>> for i, val in enumerate([10, 20, 30, 40, 50]):
...     print(i, val)
...
0 10
1 20
2 30
3 40
4 50
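To make the return value explicit, here is a minimal sketch that collects what enumerate produces into a list. The sample values are made up for illustration. The optional start argument shifts the counter, which is another way to get the r+1 used in the question.

```python
# enumerate yields (index, value) tuples; list() collects them.
pairs = list(enumerate([10, 20, 30]))
print(pairs)  # [(0, 10), (1, 20), (2, 30)]

# start=1 makes the counter begin at 1 instead of 0.
pairs_from_one = list(enumerate([10, 20, 30], start=1))
print(pairs_from_one)  # [(1, 10), (2, 20), (3, 30)]
```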
Also, iterating over a 2-D NumPy array yields one row at a time:
In [13]: d = np.array([[4, 5], [6, 7]])
In [14]: for i, val in enumerate(d):
    ...:     print(i, val)
    ...:
0 [4 5]
1 [6 7]
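So in your example, iterating over softmax_out yields its five (932, 1682) slices along the first dimension, and each slice is multiplied by its index plus one. torch.stack(..., 0) then stacks the five scaled slices back together along dimension 0, which is why the shape is preserved.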
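The same thing can be checked on a small tensor. This is only a sketch: the (5, 4, 6) shape and the random values are stand-ins for the real softmax_out, and the broadcasted form at the end is an alternative I am assuming is acceptable for your use case, not something taken from your code.

```python
import torch

# Small stand-in for the (5, 932, 1682) tensor from the question.
softmax_out = torch.rand(5, 4, 6)

# Iterating over a tensor walks its first dimension, so each `output`
# is a (4, 6) slice and `r` runs from 0 to 4.
scaled = [(r + 1) * output for r, output in enumerate(softmax_out)]

# torch.stack joins the five (4, 6) slices along a new dimension 0.
m_hat = torch.stack(scaled, 0)
print(m_hat.shape)  # torch.Size([5, 4, 6])

# Alternative without a Python loop: broadcast weights 1..5 over dim 0.
weights = torch.arange(1, 6, dtype=softmax_out.dtype).view(-1, 1, 1)
m_hat_broadcast = weights * softmax_out
```

Both m_hat and m_hat_broadcast hold the same values; the broadcasted version just avoids building a Python list of slices.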