How can I use the argmax values to index a tensor?
So, for example, I have two tensors x and y of the same shape, and I have argmax = x.min(-1) for one of them. Then I want to get the values of y at those positions, i.e. y[argmax].
How can I do that?
Your argmax should rather be:
argmax = x.max(0)[1]
since max(dim) returns a (values, indices) tuple.
Then you can use it as you did:
y[argmax]
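A small sketch of what max(dim) returns, using made-up 2x3 tensors (the values are illustrative only). It shows why the [1] is needed and what plain y[argmax] actually selects when y is 2-D:

```python
import torch

x = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])
y = torch.tensor([[10.0, 20.0, 30.0],
                  [40.0, 50.0, 60.0]])

# Tensor.max(dim) returns a (values, indices) pair; element [1] holds the indices.
values, indices = x.max(0)
print(values)   # tensor([7., 5., 3.])
print(indices)  # tensor([1, 0, 1])

# torch.argmax gives the same indices directly.
assert torch.equal(indices, x.argmax(0))

# Indexing y with a single index tensor selects whole rows of y,
# not one element per column: the result here has shape (3, 3).
rows = y[indices]
print(rows.shape)  # torch.Size([3, 3])
```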
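import torch
x = torch.rand(3,4)
y = torch.rand(3,4)
argmax = x.max(0)[1]
y[argmax]
Does not work for me: argmax has shape (4,), so y[argmax] selects whole rows of y (a 4x4 result) instead of one element per column.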
You probably want to use torch.arange(4) as a second index, or use gather or something similar.
Best regards
Thomas
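A minimal sketch of both suggestions for the dim-0 case above, assuming x and y are 3x4 as in the snippet. Both pick y[argmax[j], j] for every column j:

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

# Row index of the maximum in each column of x: shape (4,).
argmax = x.max(0)[1]

# Option 1: pair each row index with its column index via arange.
picked = y[argmax, torch.arange(y.size(1))]

# Option 2: gather along dim 0 with an index of shape (1, 4), then drop that dim.
picked_gather = y.gather(0, argmax.unsqueeze(0)).squeeze(0)

assert torch.equal(picked, picked_gather)
print(picked.shape)  # torch.Size([4])
```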
Going with Tom's idea, and tweaking the earlier code a bit:
import torch
torch.manual_seed(4)
x = torch.rand(3,4)
y = torch.rand(3,4)
print('y', y)
_, argmax = x.max(-1)  # column index of the max of x in each row, shape (3,)
print('argmax', argmax)
y[torch.arange(3), argmax] = 3  # row i gets 3 at column argmax[i]
print('y', y)
Result, as required:
y
0.9263 0.4735 0.5949 0.7956
0.7635 0.2137 0.3066 0.0386
0.5220 0.3207 0.6074 0.5233
[torch.FloatTensor of size 3x4]
argmax
0
2
1
[torch.LongTensor of size 3]
y
3.0000 0.4735 0.5949 0.7956
0.7635 0.2137 3.0000 0.0386
0.5220 3.0000 0.6074 0.5233
[torch.FloatTensor of size 3x4]
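The same indexing also reads the values instead of assigning them, which is what the original question asked for. A minimal sketch, reusing the 3x4 shapes from the code above (the random values will differ between PyTorch versions):

```python
import torch

torch.manual_seed(4)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

_, argmax = x.max(-1)              # column index of the max in each row, shape (3,)
rows = torch.arange(x.size(0))     # 0, 1, 2
values = y[rows, argmax]           # y[i, argmax[i]] for each row i, shape (3,)

# Same result with gather along the last dim.
assert torch.equal(values, y.gather(-1, argmax.unsqueeze(-1)).squeeze(-1))
print(values.shape)  # torch.Size([3])
```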
This may help: https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/blob/master/util/MaxCoord.py
It assigns 1 to the channel that has the max value across all channels.
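I haven't reproduced MaxCoord.py here; the following is only a rough sketch of the same idea using argmax plus scatter_, assuming an input of shape (N, C, H, W). The helper name max_channel_mask is hypothetical. Ties are broken by whichever index argmax returns.

```python
import torch

def max_channel_mask(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for x of shape (N, C, H, W), return a tensor of the
    same shape that is 1 in the channel holding the max at each (n, h, w), else 0."""
    idx = x.argmax(dim=1, keepdim=True)               # shape (N, 1, H, W)
    return torch.zeros_like(x).scatter_(1, idx, 1.0)  # write 1 at the argmax channel

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 4)
mask = max_channel_mask(x)
print(mask.sum(1).unique())  # every position has exactly one channel set
```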
Using torch.gather works, but this version first builds an index tensor with the same size as the original tensor. Does anyone have a better way?
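import torch

def score_max(x, dim, score):
    # Values of x at the position of the max of score along dim (score has x's shape).
    _tmp = [1] * len(x.size())
    _tmp[dim] = x.size(dim)  # repeat the index x.size(dim) times along dim
    idx = score.max(dim)[1].unsqueeze(dim).repeat(tuple(_tmp))  # full-size index tensor
    # Every slice along dim is now identical, so keep only the first one.
    return torch.gather(x, dim, idx).select(dim, 0)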
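One possible answer, sketched under the assumption that score has the same shape as x: torch.gather accepts an index whose size along dim is 1, so the repeat (and the full-size index tensor) can be skipped and the result squeezed instead. Newer PyTorch releases also provide torch.take_along_dim, which does the same gather.

```python
import torch

def score_max(x: torch.Tensor, dim: int, score: torch.Tensor) -> torch.Tensor:
    # Index of the max of score along dim, keeping that dim with size 1.
    idx = score.argmax(dim, keepdim=True)
    # gather accepts an index of size 1 along dim, so no repeat is needed;
    # squeeze removes the size-1 dim again.
    return torch.gather(x, dim, idx).squeeze(dim)

torch.manual_seed(0)
x = torch.rand(2, 3, 4)
score = torch.rand(2, 3, 4)
out = score_max(x, 1, score)
print(out.shape)  # torch.Size([2, 4])
```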