Given two 1-D tensors s and t, I would like to mark every entry of s whose value appears in t as 1, and every other entry as 0. Right now, I am implementing it in the following way:
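import torch

s = torch.tensor([0, 3, 2, 1])
t = torch.tensor([0, 2])
# Compare every element of s with every element of t, then keep the row indices that matched.
i = torch.nonzero(s[..., None] == t)[:, 0]
new_s = torch.zeros(s.size())
new_s[i] = 1
# new_s is tensor([1., 0., 1., 0.])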
However, torch.nonzero(s[..., None] == t) is quite slow for large s and t. It is also memory intensive, because the comparison s[..., None] == t materializes an intermediate boolean tensor of shape (len(s), len(t)), which runs out of memory on the GPU. Is there a better way to achieve this? Thank you very much!
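For what it's worth, here is a minimal sketch of two ways to avoid the (len(s), len(t)) intermediate. The first assumes PyTorch 1.10 or newer, where torch.isin is available. The second is a fallback built on torch.sort and torch.searchsorted; the helper name mark_membership is hypothetical and only for illustration. Neither has been benchmarked against the original tensors, so treat this as a starting point rather than a definitive answer.

```python
import torch

s = torch.tensor([0, 3, 2, 1])
t = torch.tensor([0, 2])

# Option 1 (assumes PyTorch >= 1.10): element-wise membership test.
new_s = torch.isin(s, t).float()


def mark_membership(s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 1.0 where a value of s occurs in t, else 0.0."""
    if t.numel() == 0:
        return torch.zeros(s.size(), device=s.device)
    t_sorted, _ = torch.sort(t)
    # For each element of s, find the first position in t_sorted whose value is >= it.
    idx = torch.searchsorted(t_sorted, s)
    # Values larger than max(t) get index len(t); clamp so the lookup stays in range.
    idx = idx.clamp(max=t_sorted.numel() - 1)
    # The element is a member only if the value found at that position equals it.
    return (t_sorted[idx] == s).to(torch.float32)


# Option 2: sort + binary search; extra memory is O(len(s) + len(t)).
new_s_sorted = mark_membership(s, t)

print(new_s)         # tensor([1., 0., 1., 0.])
print(new_s_sorted)  # tensor([1., 0., 1., 0.])
```

Both options keep the extra memory proportional to len(s) + len(t) in the fallback case, at the cost of one sort of t, which is usually the cheaper trade when both tensors are large.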