Mean of each class of a dataset

How can I compute the mean of each class in a dataset using PyTorch?
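One possible approach is sketched below. It assumes "mean of each class" means the average sample (or feature vector) over all items that share a label, and that each dataset item is a `(tensor, integer_label)` pair with labels in `[0, num_classes)`. `torchvision.datasets.MNIST` with `transforms.ToTensor()` returns items of this form. The helper name `per_class_mean` is hypothetical. The sketch streams the data through a `DataLoader`, so the whole dataset never has to fit in memory. It accumulates per-class sums with `Tensor.index_add_` and per-class counts with `torch.bincount`, then divides.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def per_class_mean(dataset, num_classes, batch_size=256):
    """Hypothetical helper: return a (num_classes, *sample_shape) tensor of class means.

    Assumes each item is (sample_tensor, integer_label) with labels in [0, num_classes).
    Classes with no samples get an all-zero mean.
    """
    sums, sample_shape = None, None
    counts = torch.zeros(num_classes, dtype=torch.long)
    for x, y in DataLoader(dataset, batch_size=batch_size):
        sample_shape = x.shape[1:]
        # Flatten each sample to a row; use float64 to keep long running sums accurate.
        flat = x.reshape(x.shape[0], -1).to(torch.float64)
        if sums is None:
            sums = torch.zeros(num_classes, flat.shape[1], dtype=torch.float64)
        labels = y.long()
        sums.index_add_(0, labels, flat)  # add each row into its class's running sum
        counts += torch.bincount(labels, minlength=num_classes)
    # clamp(min=1) avoids dividing by zero for classes that never appear
    means = sums / counts.clamp(min=1).unsqueeze(1)
    return means.reshape(num_classes, *sample_shape)
```

For example, with four 2-D samples where the first two belong to class 0, `per_class_mean(TensorDataset(features, labels), num_classes=3)` returns a `(3, 2)` tensor whose first row is the average of those two samples. For image data, the result keeps the per-sample shape, so each row is a "mean image" for that class. If your data already sits in a single tensor, the same `index_add_` / `bincount` pair can be applied to it directly without the `DataLoader` loop.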