Confused about how torch.argmax() supports the `axis` parameter in place of `dim`

torch.argmax

torch.argmax(input, dim, keepdim=False) → LongTensor

Parameters:

input (Tensor) – the input tensor.

dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.

keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.

There is no mention of this in the official docs, nor in the source code, but I can pass in `axis` instead of `dim` and it works the same. I'm mostly curious: how does PyTorch do it?

```python
import torch
a = torch.randn(2, 2)
b = torch.argmax(a, dim=1)
c = torch.argmax(a, axis=1)
print(b) # e.g. tensor([0, 1]); the indices depend on the random input
print(c) # same as b, e.g. tensor([0, 1])
```
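Because `torch.randn` produces different values on every run, the printed indices above will vary. A deterministic version of the same check, which also tries the NumPy-style `keepdims` spelling, could look like this (a quick sketch, assuming a reasonably recent PyTorch release):

```python
import torch

# Fixed input so the result is reproducible: row 0 peaks at index 1, row 1 at index 0.
a = torch.tensor([[0.1, 0.9],
                  [0.7, 0.2]])

b = torch.argmax(a, dim=1)
c = torch.argmax(a, axis=1)   # NumPy-style name for `dim`
print(b, c)                   # tensor([1, 0]) tensor([1, 0])
assert torch.equal(b, c)

# The NumPy spelling `keepdims` is accepted in place of `keepdim` as well.
d = torch.argmax(a, dim=1, keepdim=True)
e = torch.argmax(a, axis=1, keepdims=True)
print(d.shape, e.shape)       # torch.Size([2, 1]) torch.Size([2, 1])
assert torch.equal(d, e)
```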

The aliasing should be performed in python_arg_parser.cpp here:

```cpp
// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names = {
        {"dim", {"axis"}},
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
        {"other", {"x2"}},
};
```
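To make the lookup concrete, here is a minimal pure-Python sketch of how a table like this can be used to resolve keyword arguments. It is not PyTorch's actual implementation: the helper `resolve_kwargs` and the per-function `param_names` list are hypothetical, and rejecting a call that passes both a name and its alias (e.g. `dim` and `axis`) is a choice made for this sketch only.

```python
from typing import Any, Dict, List, Sequence

# Same table as the C++ snippet above, copied into a Python dict.
NUMPY_COMPATIBILITY_ARG_NAMES: Dict[str, List[str]] = {
    "dim": ["axis"],
    "keepdim": ["keepdims"],
    "input": ["x", "a", "x1"],
    "other": ["x2"],
}


def resolve_kwargs(param_names: Sequence[str], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: map NumPy-style keyword names onto a function's
    canonical parameter names using the alias table above."""
    remaining = dict(kwargs)
    resolved: Dict[str, Any] = {}
    for name in param_names:
        # Try the canonical name first, then every NumPy alias registered for it.
        candidates = [name] + NUMPY_COMPATIBILITY_ARG_NAMES.get(name, [])
        found = [c for c in candidates if c in remaining]
        if len(found) > 1:
            raise TypeError(f"got multiple values for argument '{name}'")
        if found:
            resolved[name] = remaining.pop(found[0])
    if remaining:
        # Anything left over matched neither a parameter nor one of its aliases.
        unexpected = next(iter(remaining))
        raise TypeError(f"got an unexpected keyword argument '{unexpected}'")
    return resolved


argmax_params = ["input", "dim", "keepdim"]
print(resolve_kwargs(argmax_params, {"a": "t", "axis": 1, "keepdims": True}))
# {'input': 't', 'dim': 1, 'keepdim': True}
```

Because the aliases are keyed by parameter name rather than by function, any function with a parameter called `dim` picks up `axis` for free in this model.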

I see, this makes it much clearer! Also, how can we find out which functions this applies to, or do I have to check the NumPy docs for the function with the same name?

I don’t know exactly which functions support this argument aliasing, but I would guess that every argument is checked against this mapping.
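One empirical way to check a particular function, instead of going through the NumPy docs, is to call it with the NumPy-style keyword and see whether PyTorch's argument parser rejects it with a `TypeError`. The helper name `accepts_numpy_names` below is hypothetical, and it is only a rough probe: a `TypeError` raised for some unrelated reason (e.g. a wrong value type) is also reported as `False`.

```python
import torch


def accepts_numpy_names(fn, *args, **kwargs) -> bool:
    """Return True if `fn` runs with the given (possibly NumPy-style) keywords,
    False if calling it raises a TypeError."""
    try:
        fn(*args, **kwargs)
    except TypeError:
        return False
    return True


t = torch.randn(3, 4)

print(accepts_numpy_names(torch.argmax, t, axis=1))               # True
print(accepts_numpy_names(torch.sum, t, axis=0, keepdims=True))   # True
print(accepts_numpy_names(torch.mean, t, axis=0))                 # True
print(accepts_numpy_names(torch.abs, t, axis=0))                  # False: abs has no `dim`
```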
