Hi, I am wondering what the difference is between the newly added take_along_dim and gather. They seem pretty similar, except that take_along_dim lets the user omit the dim parameter. In that case it seems that torch can find the best broadcast on its own. For example, I tried logits of shape (N, K) with an index of logits.argmax(-1), i.e. shape (N,), and it still seemed to find the correct broadcast (N, 1), which I believe is not possible under traditional numpy broadcasting.
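For reference, here is a minimal sketch for comparing the two calls on a small tensor. The values are made up for illustration; only logits and the argmax index come from the description above, while idx and cols are names introduced here. As far as I can tell from the documentation, take_along_dim with dim omitted flattens both tensors to 1-D before indexing, so it may not be doing a per-row broadcast at all.

```python
import torch

# Hypothetical example values; `logits` mirrors the name used in the question.
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2],
                       [0.0, -0.3, 1.5],
                       [0.7, 0.9, 0.8]])     # shape (N, K) = (4, 3)
idx = logits.argmax(-1)                        # shape (N,)

# gather: the index needs the same number of dims as the input.
via_gather = torch.gather(logits, -1, idx.unsqueeze(-1))            # shape (N, 1)

# take_along_dim with an explicit dim gives the same per-row selection.
via_take = torch.take_along_dim(logits, idx.unsqueeze(-1), dim=-1)  # shape (N, 1)

# With an explicit dim, take_along_dim broadcasts the index against the
# input along the other dimensions.
cols = torch.tensor([[0, 2]])                                       # shape (1, 2)
broadcasted = torch.take_along_dim(logits, cols, dim=1)             # shape (N, 2)

# With dim omitted, both tensors are flattened to 1-D first, so this
# indexes into logits.flatten() rather than into each row.
flat_take = torch.take_along_dim(logits, idx)                       # shape (N,)
```

On this input, flat_take matches logits.flatten()[idx] rather than the per-row maximum, so a result that looks right with dim omitted may just be a coincidence of the values. Passing dim=-1 together with idx.unsqueeze(-1) appears to be the safer way to get the (N, 1) selection.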
Thanks!