There seems to be no difference between `a.max(1, keepdim=True)[1]` and `a.max(1, keepdim=True)[1].data`:
>>> import torch
>>> a=torch.randn(4,4)
>>> a
tensor([[ 2.0155, -1.1138, -1.7522,  0.7299],
        [ 1.0620,  0.1840,  0.2790,  1.1942],
        [ 1.7519,  1.8871,  1.4988, -0.1911],
        [-1.6222, -0.2044,  1.6316, -1.0949]])
>>> a.max(1,keepdim=True)[1]
tensor([[0],
        [3],
        [1],
        [2]])
>>> a.max(1,keepdim=True)[1].data
tensor([[0],
        [3],
        [1],
        [2]])