Assigning values to a tensor sliced by indices and a mask

Hello everyone,

I am looking for a way to make this assignment to a slice of a tensor work without copying data (and I would also like to understand why my version doesn't work):

import torch

indices = torch.tensor([2,0])
lengths_new = torch.tensor([1,2])
lengths_old = torch.tensor([2,2,3])
tensor_new = torch.tensor([[3,1,2,4], [2,1,3,4]])
tensor_old = torch.tensor([[1,2,3,4], [1,2,3,4], [2,1,3,4]])

mask = lengths_new < lengths_old[indices]
print("old value", tensor_old[indices][mask])
print("new value", tensor_new[mask])
tensor_old[indices][mask] = tensor_new[mask]
print("result value", tensor_old[indices][mask])
print("old value", lengths_old[indices][mask])
print("new value", lengths_new[mask])
lengths_old[indices][mask] = lengths_new[mask]
print("result value", lengths_old[indices][mask])

The output is this:

old value tensor([[2, 1, 3, 4]])
new value tensor([[3, 1, 2, 4]])
result value tensor([[2, 1, 3, 4]])
old value tensor([3])
new value tensor([1])
result value tensor([3])

I would like the result value to be the same as the new value. Any suggestions on how to achieve this are very welcome.

Best regards
H4ns

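If I understand your code correctly, you are currently working on a copy of the data because you index the tensor sequentially. tensor_old[indices] uses advanced indexing with an index tensor, which returns a new tensor rather than a view, so the following [mask] = ... assignment writes into that temporary copy and tensor_old itself is never changed. If you index tensor_old only once, with the combined index indices[mask], the assignment modifies tensor_old in place.
This should work and yields the same values for "new value" and "result value":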

import torch

indices = torch.tensor([2,0])
lengths_new = torch.tensor([1,2])
lengths_old = torch.tensor([2,2,3])
tensor_new = torch.tensor([[3,1,2,4], [2,1,3,4]])
tensor_old = torch.tensor([[1,2,3,4], [1,2,3,4], [2,1,3,4]])

mask = lengths_new < lengths_old[indices]
print("old value", tensor_old[indices][mask])
print("new value", tensor_new[mask])
# Index tensor_old only once, so the assignment writes into tensor_old itself
tensor_old[indices[mask]] = tensor_new[mask]
print("result value", tensor_old[indices][mask])
print("old value", lengths_old[indices][mask])
print("new value", lengths_new[mask])
# Same fix for the lengths: one indexing step, then assign in place
lengths_old[indices[mask]] = lengths_new[mask]
print("result value", lengths_old[indices][mask])
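To see the copy-versus-view behaviour directly, here is a small sketch using the same data as the question. The names gathered and row_view are hypothetical, chosen only for illustration. It shows that advanced indexing hands back a copy, that basic slicing hands back a view, and that the single-step assignment does update tensor_old. The last part uses Tensor.index_copy_ as an explicit in-place alternative for the lengths. That is my choice of an equivalent call, not something from the original answer.

```python
import torch

# Same data as in the question.
indices = torch.tensor([2, 0])
lengths_new = torch.tensor([1, 2])
lengths_old = torch.tensor([2, 2, 3])
tensor_new = torch.tensor([[3, 1, 2, 4], [2, 1, 3, 4]])
tensor_old = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4], [2, 1, 3, 4]])
mask = lengths_new < lengths_old[indices]  # tensor([True, False])

# Advanced indexing with an integer tensor returns a new tensor (a copy).
gathered = tensor_old[indices]
gathered[0, 0] = 99
print(tensor_old[2, 0].item())  # tensor_old is unchanged

# Basic slicing returns a view that shares storage with tensor_old.
row_view = tensor_old[2:3]
row_view[0, 0] = 99
print(tensor_old[2, 0].item())  # the write is visible in tensor_old
tensor_old[2, 0] = 2  # restore the original value

# Chained indexing: the [mask] assignment only changes a temporary copy.
tensor_old[indices][mask] = tensor_new[mask]
print(tensor_old[2])  # still the old row

# A single indexing step writes into tensor_old itself.
tensor_old[indices[mask]] = tensor_new[mask]
print(tensor_old[2])

# Explicit in-place alternative along dim 0 (index must be an int64 tensor).
lengths_old.index_copy_(0, indices[mask], lengths_new[mask])
print(lengths_old)
```

One caveat applies to both the original fix and index_copy_. If indices[mask] contains the same row index more than once, PyTorch does not guarantee which of the duplicate source rows ends up in that row.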

Ah, very nice, thank you! :)