What is the best way to copy_() after advanced indexing?


import numpy as np
import torch

a = torch.zeros((4, 10, 20, 20))
b = torch.ones((3, 3))
i,j,k,h = np.meshgrid([1,2], [2,1], range(3), range(3), indexing='ij')
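A minimal runnable version of this setup shows the shape the advanced-indexing selection produces. It assumes PyTorch 1.10 or later (for the `indexing` argument of `torch.meshgrid`) and uses torch index tensors in place of the NumPy arrays:

```python
import torch

a = torch.zeros((4, 10, 20, 20))
b = torch.ones((3, 3))

# torch.meshgrid with indexing="ij" mirrors np.meshgrid(..., indexing="ij")
i, j, k, h = torch.meshgrid(
    torch.tensor([1, 2]), torch.tensor([2, 1]),
    torch.arange(3), torch.arange(3), indexing="ij",
)

# Each index tensor has shape (2, 2, 3, 3), so the selection does too
print(i.shape)              # torch.Size([2, 2, 3, 3])
print(a[i, j, k, h].shape)  # torch.Size([2, 2, 3, 3])
```

That (2, 2, 3, 3) shape is why b, which is (3, 3), has to be broadcast along two extra leading dimensions.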

I would like to copy the elements of b into a at the indices [i,j,k,h], with broadcasting, as in the following code, which I expected to work:

a[i,j,k,h].copy_(b)  # b is broadcast to the indexed shape

However, indexing with integer arrays (advanced indexing) goes through __getitem__, which returns a copy rather than a view, so copy_() writes into a temporary tensor and the result is never stored in a. To get what I want, I have to repeat() b and assign the result to a:

a[i,j,k,h] = b.repeat(2,2,1,1)

In this case a[i,j,k,h] is not cloned, because the assignment goes through __setitem__ rather than __getitem__ as in the first case. But I am worried about the efficiency of this code, because a and b could be very large. I suspect copy_() with broadcasting would be more memory efficient, since it would not need to materialize a repeated copy of b.
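The difference between the two cases can be checked directly. This sketch (same assumptions as before: PyTorch 1.10+, torch index tensors instead of NumPy arrays) records the sum of a after each attempt:

```python
import torch

a = torch.zeros((4, 10, 20, 20))
b = torch.ones((3, 3))
i, j, k, h = torch.meshgrid(
    torch.tensor([1, 2]), torch.tensor([2, 1]),
    torch.arange(3), torch.arange(3), indexing="ij",
)

# __getitem__ with index tensors returns a new tensor, so copy_() fills only that copy
a[i, j, k, h].copy_(b)        # b is broadcast from (3, 3) to (2, 2, 3, 3)
getitem_sum = a.sum().item()  # a is unchanged

# __setitem__ writes through to a's storage
a[i, j, k, h] = b.repeat(2, 2, 1, 1)
setitem_sum = a.sum().item()  # 2 * 2 * 3 * 3 = 36 distinct positions set to 1

print(getitem_sum, setitem_sum)  # 0.0 36.0
```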

I also found index_copy_() and masked_scatter_(). index_copy_() only indexes along a single dimension, which is not what I want, and masked_scatter_() needs a mask the same size as a, which again is not very memory efficient.
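One workaround for the single-dimension limit of index_copy_(), which is not from the original post, is to fold the four index tensors into one linear index into a contiguous a and call index_copy_() on a 1-D view. This is a sketch assuming a is contiguous (so a.view(-1) shares its storage). Note that index_copy_() needs a source with exactly one element per index, so b is still materialized to 36 elements here, and this does not remove the memory concern:

```python
import torch

a = torch.zeros((4, 10, 20, 20))
b = torch.ones((3, 3))
i, j, k, h = torch.meshgrid(
    torch.tensor([1, 2]), torch.tensor([2, 1]),
    torch.arange(3), torch.arange(3), indexing="ij",
)

# Combine the multi-dimensional indices into a linear index using a's strides
s0, s1, s2, s3 = a.stride()
flat_idx = (i * s0 + j * s1 + k * s2 + h * s3).reshape(-1)

# reshape(-1) on the expanded (non-contiguous) view copies b into 36 elements
src = b.expand(i.shape).reshape(-1)

# a.view(-1) shares storage with a, so the in-place copy updates a
a.view(-1).index_copy_(0, flat_idx, src)

print(a.sum().item())  # 36.0
```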

Can anyone give me suggestions on the best way to achieve this? Thanks!

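Update: I found that expand() can be used instead of repeat(). It returns a view of b in which the new dimensions have stride 0, so no repeated copy of b is allocated, and the problem is solved: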

a[i,j,k,h] = b.expand(2,2,3,3)
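The memory difference between repeat() and expand() shows up in the strides and storage pointers. The last part of this sketch also assigns b directly; that relies on the assumption that indexed assignment broadcasts the value to the indexed shape, which holds in recent PyTorch versions (same setup assumptions as above):

```python
import torch

a = torch.zeros((4, 10, 20, 20))
b = torch.ones((3, 3))
i, j, k, h = torch.meshgrid(
    torch.tensor([1, 2]), torch.tensor([2, 1]),
    torch.arange(3), torch.arange(3), indexing="ij",
)

rep = b.repeat(2, 2, 1, 1)  # new contiguous tensor holding 36 elements
exp = b.expand(2, 2, 3, 3)  # view over b's 9 elements; new dims have stride 0

print(rep.stride())                    # (18, 9, 3, 1)
print(exp.stride())                    # (0, 0, 3, 1)
print(exp.data_ptr() == b.data_ptr())  # True: expand allocates no new storage

a[i, j, k, h] = exp

# Assigning b itself lets __setitem__ do the broadcasting
a2 = torch.zeros_like(a)
a2[i, j, k, h] = b
print(torch.equal(a, a2))  # True
```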