How to map the elements of a tensor via key-value pairs (sparse version of index_select)

Hi All,

I’m trying to transform an input tensor into a desired output using a key-value mapping, which is defined by two tensors that hold the keys and the values separately. I thought of using a simple for-loop, but I fear that approach could overwrite earlier mappings once key_tensor and value_tensor get larger: for example, a value written for one key might equal a key handled later and be remapped again.
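The overwrite concern can be reproduced with a tiny hypothetical mapping (keys [1, 2], values [2, 5]) in which the value 2 is also a key. The two helpers below, `naive_inplace_map` and `loop_map`, are illustrative names, not PyTorch functions. The first matches against the tensor it is modifying, so an element can be remapped twice; the second always matches against the original input, which avoids that, but it still makes one full pass over the tensor per key.

```python
import torch

def naive_inplace_map(x, keys, values):
    # Matches against `out` while writing into it, so a value written in an
    # earlier iteration can match a later key and be remapped again.
    out = x.clone()
    for k, v in zip(keys.tolist(), values.tolist()):
        out[out == k] = v
    return out

def loop_map(x, keys, values):
    # Matches against the ORIGINAL tensor `x`, so earlier writes are never
    # re-matched; elements not found in `keys` are left unchanged.
    out = x.clone()
    for k, v in zip(keys.tolist(), values.tolist()):
        out[x == k] = v
    return out

# Hypothetical example chosen so that a value (2) is also a key.
x = torch.tensor([1, 2, 3])
keys = torch.tensor([1, 2])
values = torch.tensor([2, 5])

print(naive_inplace_map(x, keys, values))  # 1 -> 2 -> 5: chained remapping
print(loop_map(x, keys, values))           # 1 -> 2 and 2 -> 5 independently
```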

Is there an existing function in PyTorch that can do the following efficiently?

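import torch

input_tensor = torch.tensor([[ 9, 36, 36],
                             [18, 36, 36],
                             [18, 36, 36],
                             [18, 36, 36]])

key_tensor = torch.tensor([36, 18, 9])
value_tensor = torch.tensor([0, 1, 2])

def expected_func(input, keys, values):
    """Return a tensor shaped like `input` in which every element equal
    to keys[i] is replaced by values[i].
    """
    # some amazing code here...
    raise NotImplementedError

# Desired result of expected_func(input_tensor, key_tensor, value_tensor)
expected_output = torch.tensor([[2, 0, 0],
                                [1, 0, 0],
                                [1, 0, 0],
                                [1, 0, 0]])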