I want to copy_ directly into an indexed region of the dst tensor. Since the dtypes are different, it is not possible to assign directly (Y[…] = X).
So I used copy_, but copying into the indexed tensor (Y[:, index]) doesn't work as expected: Y is unchanged afterwards.
Can anyone confirm whether I'm using it wrong?
import torch
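# Source is float32, destination is float16 with an extra leading dimension of 64
X = torch.randn([2, 360, 5000], device="cuda", dtype=torch.float32)
Y = torch.zeros([64, 2, 360, 5000], device="cuda", dtype=torch.float16)
# Boolean mask over the last three dimensions
index = torch.randint(0, 2, [2, 360, 5000], device="cuda") == 1
# Intended: write the masked values of X into every slice of Y along dim 0
Y[:, index].copy_(X[None, index], non_blocking=True)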
print(torch.allclose(Y[0, index].float(), X[index])) # False
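
A possible explanation, as a minimal sketch rather than a confirmed answer: indexing with a boolean mask is advanced indexing, which returns a new tensor instead of a view, so copy_ fills a temporary and Y never changes. One workaround is to cast the source to Y's dtype and assign through the mask. The smaller shapes and the CPU fallback below are my own assumptions so the example runs anywhere; the names `untouched` and `matches` are only for illustration.

```python
import torch

# Assumption: fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Smaller shapes than in the post, only to keep the sketch quick to run.
X = torch.randn([2, 36, 50], device=device, dtype=torch.float32)
Y = torch.zeros([4, 2, 36, 50], device=device, dtype=torch.float16)
index = torch.randint(0, 2, [2, 36, 50], device=device) == 1

# Boolean-mask indexing returns a new tensor, so copy_ writes into a
# temporary and Y itself is left untouched.
Y[:, index].copy_(X[None, index])
untouched = bool((Y == 0).all())

# Cast the source to Y's dtype, then assign through the mask.
# X[index] has shape [N] and broadcasts over the leading dimension of Y.
Y[:, index] = X[index].to(Y.dtype)

# Compare against X rounded to float16: the values stored in Y have
# float16 precision, so an allclose check against the float32 X with
# default tolerances can still report False after a correct write.
matches = torch.equal(Y[0, index], X[index].to(torch.float16))
print(untouched, matches)
```

If the float32 comparison from the original snippet is preferred, passing looser `rtol`/`atol` values to `torch.allclose` should account for the float16 rounding.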