Returning the indices of all max elements in a 1D tensor


I want a function that returns the indices of the maximum elements of a 1D tensor. The maximum value may occur more than once. torch.max() seems to return only one index, but I want all of them.
For example, if a = tensor([1, 4, 56, 6, 7, 56, 7]), I want to get [2, 5].
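One possible approach, sketched here under the assumption of a reasonably recent PyTorch version (one that supports `nonzero(as_tuple=True)`), is to compare the tensor against its maximum value and collect every position where that comparison is True. The helper name `argmax_all` is hypothetical and only used for illustration.

```python
import torch


def argmax_all(t: torch.Tensor) -> torch.Tensor:
    """Return the indices of every element equal to the maximum of a 1D tensor."""
    # t.max() with no dim argument returns a 0-dim tensor holding the max value;
    # comparing against it broadcasts to a boolean mask of the same shape as t.
    mask = t == t.max()
    # nonzero(as_tuple=True) returns one index tensor per dimension;
    # for a 1D tensor only the first entry is needed.
    return mask.nonzero(as_tuple=True)[0]


a = torch.tensor([1, 4, 56, 6, 7, 56, 7])
print(argmax_all(a).tolist())  # [2, 5]
```

`torch.where(mask)[0]` should give the same result. Note that if the tensor contains NaN, `t.max()` returns NaN and the equality check matches nothing, so this sketch assumes NaN-free input.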