I have a code snippet that uses index_copy_:
import torch
torch.set_default_device('cpu')
data = torch.zeros((5, 5))
part = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]], dtype=torch.float32)
index = torch.tensor([-1, 3])
data.index_copy_(0, index, part)
print(data)
And I want -1 to be ignored, rather than throwing an error. Is it possible?
I can build mask = index != -1 and then call data.index_copy_(0, index[mask], part[mask]), but part[mask] has a variable, data-dependent length, so it does not work with CUDA graphs.
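
One possible fixed-shape workaround, sketched under assumptions rather than offered as a confirmed answer: give the destination one extra trailing "scratch" row and redirect every negative index to it, so no tensor ever changes length. The helper name index_copy_ignoring_negative and the padded buffer are hypothetical, introduced only for this illustration. It also assumes that allocating one extra row is acceptable and that it does not matter what ends up in the scratch row. Whether this captures cleanly in a CUDA graph has not been verified here.

```python
import torch

def index_copy_ignoring_negative(padded, index, part):
    """Copy rows of `part` into `padded`, discarding rows whose index is negative.

    `padded` is assumed to hold the real data plus one trailing scratch row.
    """
    scratch_row = padded.shape[0] - 1
    # Same shape as `index`: negative entries are redirected to the scratch row.
    safe_index = index.masked_fill(index < 0, scratch_row)
    padded.index_copy_(0, safe_index, part)
    # View of the real rows only; the scratch row is dropped.
    return padded[:scratch_row]

padded = torch.zeros((6, 5))  # 5 real rows + 1 scratch row
part = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]], dtype=torch.float32)
index = torch.tensor([-1, 3])
data = index_copy_ignoring_negative(padded, index, part)
print(data)
```

If several entries are -1, they all land on the scratch row, so the duplicate writes only affect a row that is thrown away.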