Global argmax as index in tensor

I’m new to Torch, and there is one thing I find counterintuitive about the argmin and argmax operations.
Take this example from the docs:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
        [ 1.0100, -1.1975, -0.0102, -0.4732],
        [-0.9240,  0.1207, -0.7506, -1.0213],
        [ 1.7809, -1.2960,  0.9384,  0.1438]])
>>> torch.argmin(a)
tensor(13)

But wouldn’t it be more natural to return tensor([3, 1])? In my particular case, I have two distinct sets of points A and B, and I am looking for the indices of a pair (a, b) ∈ A × B. I can measure pairwise distances using cdist, but a single flat index, a 42-like answer, is of little use. Sure, the row and column indices can be inferred from it, but doing so looks messy and superfluous.
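To make the "can be inferred" part concrete, here is a minimal sketch of the inference I mean. It assumes a 2-D distance matrix from cdist and that the closest pair is the one wanted; the point sets A and B below are made-up values for illustration only.

```python
import torch

# Two small, hand-picked point sets (hypothetical data, not from a real problem)
A = torch.tensor([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
B = torch.tensor([[4.0, 4.0], [20.0, 20.0]])

# Pairwise Euclidean distances, shape (len(A), len(B))
d = torch.cdist(A, B)

# Without a dim argument, argmin returns an index into the flattened tensor
flat = torch.argmin(d)

# Recover (row, col) from the flat index, relying on row-major order of a 2-D tensor
i, j = divmod(flat.item(), d.size(1))
print(i, j)  # index into A, index into B
```

This works, but the divmod step only covers the 2-D case, and it is exactly the kind of bookkeeping I was hoping to avoid.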
Is there a better solution? Also, is there any reason for returning argmin/argmax this way, beyond backward compatibility and the fact that the search runs over an effectively linear array?