Altering a Tensor Along a Variable Dimension

Hi all,

Is there a way to modify a tensor given a dimension and a list of indices along it? For example, suppose I want to multiply indices 1 and 2 along the third dimension (dim=2) of a tensor by 2. I could do:

import torch

t = torch.randn([5, 5, 5, 5, 5])
t[:, :, [1, 2], :, :] *= 2  # scales indices 1 and 2 along dim 2

But suppose I don't know the dimension ahead of time. Is there a clean way to do the following?

def multiply_by_2(t, idx, dim):
    if dim == 0:
        t[idx, ...] *= 2
    elif dim == 1:
        t[:, idx, ...] *= 2
    # ...and so on, one branch for every possible dim
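One way to avoid the per-dimension branches is to build the index tuple programmatically: a full slice (`:`) for every dimension except `dim`, which gets `idx`. The following is a minimal sketch, not taken from the replies, assuming `idx` is a Python list (or 1-D integer tensor) and that modifying `t` in place is acceptable.

```python
import torch

def multiply_by_2(t: torch.Tensor, idx, dim: int) -> None:
    """Multiply the slices selected by `idx` along `dim` by 2, in place."""
    # Start with a full slice (":") for every dimension...
    index = [slice(None)] * t.ndim
    # ...then replace the entry for `dim` with the requested indices.
    index[dim] = idx
    # Indexing with this tuple is equivalent to writing e.g. t[:, :, idx, :, :] by hand.
    t[tuple(index)] *= 2

t = torch.randn(5, 5, 5, 5, 5)
expected = t.clone()
expected[:, :, [1, 2], :, :] *= 2
multiply_by_2(t, [1, 2], dim=2)
print(torch.equal(t, expected))  # True
```

Because the tuple is built from `t.ndim`, the same function also works with negative dims such as `dim=-1`.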

In other words, is there something like index_select that modifies the original tensor instead of returning a copy?

Thanks

Perhaps there is a better way, but here's an ugly solution: permute the desired dimension into the first position, make the change there, and then permute it back.

import numpy as np
import torch

t = torch.randn([4, 4, 4])  # our starting tensor

# your way: hard-coded for dim 1
t0 = torch.clone(t)
t0[:, [1, 2]] = 1.

# proposed way, via permuting
dim = 1  # desired dimension to edit
t1 = torch.clone(t)
dims = list(range(t1.ndim))
dims.remove(dim)
dims = [dim] + dims          # move `dim` to the front, keep the others in order
t1 = t1.permute(dims)
t1[[1, 2]] = 1               # edit along what is now dim 0
t1 = t1.permute(np.argsort(dims).tolist())  # argsort gives the inverse permutation

print(torch.allclose(t0, t1))
Output:
True
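A possibly shorter variant of the same idea, sketched here under the assumption of PyTorch 1.7 or later (where `torch.movedim` was added): `movedim` returns a view, so assigning into that view writes directly into `t1`'s storage, and the inverse permutation is not needed at all.

```python
import torch

t = torch.randn(4, 4, 4)

# Reference: the hard-coded version for dim = 1.
t0 = t.clone()
t0[:, [1, 2]] = 1.0

# movedim returns a view with `dim` moved to the front; writing into the
# view writes into t1's storage, so nothing has to be permuted back.
dim = 1
t1 = t.clone()
t1.movedim(dim, 0)[[1, 2]] = 1.0

print(torch.equal(t0, t1))  # True
```

`t1.transpose(0, dim)` would also return a view and could be used the same way.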

Yuck!

You can use index_select to extract the slices, modify them, and then write them back with index_copy_.

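import torch

a = torch.rand(5, 5, 5, 5, 5, requires_grad=True)
dim = 2
index = torch.tensor([2, 3])  # index_copy_ expects an int64 (long) index tensor

def index_select_mult(a, dim, index, mult):
    b = a.clone()  # copy, since a leaf tensor requiring grad can't be modified in place
    temp = torch.index_select(a, dim, index)  # the selected slices along `dim`
    temp = temp * mult
    b.index_copy_(dim, index, temp)  # write the scaled slices back into b
    return b

print(a)
a = index_select_mult(a, dim, index, 10)
print(a)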
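If the clone is only there to avoid modifying `a`, the out-of-place `Tensor.index_copy` (no trailing underscore) may read a little more compactly. This is a sketch of the same index_select/index_copy approach, not a different technique; it also checks that gradients flow back through both operations.

```python
import torch

def index_select_mult(a: torch.Tensor, dim: int, index: torch.Tensor, mult: float) -> torch.Tensor:
    # index_copy (out-of-place) returns a new tensor, so `a` is left untouched
    # and no explicit clone() is needed.
    return a.index_copy(dim, index, a.index_select(dim, index) * mult)

a = torch.rand(5, 5, 5, 5, 5, requires_grad=True)
index = torch.tensor([2, 3])
b = index_select_mult(a, 2, index, 10.0)

# d(sum(b))/da is 10 on the selected slices and 1 everywhere else.
b.sum().backward()
print(torch.all(a.grad[:, :, 2] == 10).item())  # True
print(torch.all(a.grad[:, :, 0] == 1).item())   # True
```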

Hope this helps!

Thanks guys - both approaches seem to work!
