I want to calculate the mean along the y-axis, so the output should be 1x3.
How should I do this?
Thank you.
import torch
x = torch.rand(3, 3)
# x.mean(1) reduces over dim 1, i.e. it averages across each row (one value per row)
print(x.mean(1))
x.mean(0) was what I wanted.
Thank you so much^^!!
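For anyone landing here later, here is a small sketch contrasting the two reductions on a fixed 3x3 tensor (the values are chosen only for illustration). dim=0 averages down each column, dim=1 averages across each row. Both return shape (3,). If you literally need a 1x3 result, `keepdim=True` keeps the reduced dimension.

```python
import torch

# Illustrative fixed input so the results are easy to check by hand
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0],
                  [7.0, 8.0, 9.0]])

col_means = x.mean(0)                        # mean down each column -> tensor([4., 5., 6.]), shape (3,)
row_means = x.mean(1)                        # mean across each row  -> tensor([2., 5., 8.]), shape (3,)
col_means_2d = x.mean(dim=0, keepdim=True)   # same values as col_means, but shape (1, 3)

print(col_means)
print(row_means)
print(col_means_2d.shape)  # torch.Size([1, 3])
```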