I have no idea how torch.mean(dim=0) works.
For example, when I run the code below:
#------------------------------------------------
import torch
# float values, since mean() needs a floating-point tensor
t = torch.tensor([[1., 2.], [3., 4.]])
print(t.mean(dim=0))
#------------------------------------------------
The output is:
tensor([2., 3.])
I don’t know why the output looks like this.
Could anyone with some insight explain why it works this way?
spanev
(Serge Panev)
It means: give me the mean of the elements, iterating along axis dim.
In your example, what it does is:
(t[0] + t[1]) / num_elements = ([1, 2]+ [3, 4]) / 2 = [2., 3.]
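A quick way to check this reading is to redo the computation by hand and compare it with the built-in result. In this minimal sketch (which uses torch.tensor instead of the legacy torch.FloatTensor), num_elements is the number of rows, t.size(0):

```python
import torch

t = torch.tensor([[1., 2.], [3., 4.]])

# dim=0 walks down the rows: add row 0 and row 1 element-wise,
# then divide by how many rows there were.
manual = (t[0] + t[1]) / t.size(0)
builtin = t.mean(dim=0)

print(manual)   # tensor([2., 3.])
print(builtin)  # tensor([2., 3.])
print(torch.equal(manual, builtin))  # True
```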
Similarly, if you specified dim=1, it would do:
(t[:,0] + t[:,1]) / 2 = ([1, 3] + [2, 4]) / 2 = [1.5, 3.5]
which is the same result as
>>> t.mean(dim=1)
tensor([1.5000, 3.5000])
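The same hand check works for dim=1, where the columns are added element-wise and divided by the number of columns, t.size(1). A short sketch:

```python
import torch

t = torch.tensor([[1., 2.], [3., 4.]])

# dim=1 walks across the columns: add column 0 and column 1
# element-wise, then divide by how many columns there were.
manual = (t[:, 0] + t[:, 1]) / t.size(1)
builtin = t.mean(dim=1)

print(manual)   # tensor([1.5000, 3.5000])
print(builtin)  # tensor([1.5000, 3.5000])
```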
The dim parameter defines over which dimension of the tensor you are taking the mean. In your case, with dim=0:
t = [[1, 2],
     [3, 4]]
t.mean(dim=0) = [(1+3)/2, (2+4)/2] = [2,3]
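To make the indexing explicit, here is a plain-Python sketch of the dim=0 rule with no PyTorch involved (mean_dim0 is a hypothetical helper name used only for illustration): for every column index j, average t[i][j] over all row indices i. The index that gets averaged away is the one named by dim.

```python
def mean_dim0(rows):
    """Hypothetical helper: average over the first index (rows), one result per column."""
    n_rows = len(rows)
    n_cols = len(rows[0])
    return [sum(rows[i][j] for i in range(n_rows)) / n_rows
            for j in range(n_cols)]

t = [[1, 2], [3, 4]]
print(mean_dim0(t))  # [2.0, 3.0]
```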
For dim=1:
t = [[1, 2],
     [3, 4]]
t.mean(dim=1) = [(1+2)/2, (3+4)/2] = [1.5,3.5]
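One way to remember the rule for tensors with more dimensions: the dimension passed as dim is the one that disappears from the result's shape. The sketch below uses a random 3-D tensor (the shape 2×3×4 is an arbitrary choice for illustration) and also shows keepdim=True, which keeps the averaged dimension with size 1, and a negative index, where dim=-1 means the last dimension.

```python
import torch

x = torch.rand(2, 3, 4)  # arbitrary example shape

print(x.mean(dim=0).shape)  # torch.Size([3, 4])
print(x.mean(dim=1).shape)  # torch.Size([2, 4])
print(x.mean(dim=2).shape)  # torch.Size([2, 3])

# keepdim=True leaves the averaged dimension in place with size 1.
print(x.mean(dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])

# dim=-1 is the last dimension, so it matches dim=2 here.
print(torch.equal(x.mean(dim=-1), x.mean(dim=2)))  # True
```

keepdim=True is handy when the mean is broadcast back against the original tensor, for example x - x.mean(dim=1, keepdim=True).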
Hope this helps!
Regards
Helped me a lot,
thank you!