Is there an efficient way to apply Python’s `in` operator (`__contains__`) elementwise to a tensor?
This is the behaviour I’m looking for:
>>> some_values = [3, 5, 8]
>>> m = torch.randint(0, 10, (2, 2))
>>> m in some_values
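To pin down the intended result, here is a slow reference version that applies Python's `in` to every element in a Python loop. It is a sketch for illustration only: the fixed tensor replaces the random one so the result is predictable, and the name `expected` is hypothetical.

```python
import torch

some_values = [3, 5, 8]
m = torch.tensor([[3, 1], [8, 5]])  # fixed values instead of torch.randint, for a predictable result

# Reference semantics: Python's `in` applied to each element (loops in Python, so slow)
expected = torch.tensor([[v in some_values for v in row] for row in m.tolist()])
```

`expected` is a boolean tensor with the same shape as `m`, True exactly where an element of `m` appears in `some_values` (here: 3, 8 and 5, but not 1).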
Thanks in advance!
(m == some_values[0]) | (m == some_values[1]) | (m == some_values[2]) is one way to do it. I’m not sure how efficient it is.
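The chained comparison above can be generalised to a list of any length by folding `|` over one comparison per value. A minimal sketch; `isin_or_chain` is a hypothetical helper name, and starting from an all-False tensor keeps an empty list from raising an error:

```python
import functools
import operator

import torch


def isin_or_chain(m, values):
    # One elementwise comparison per value, combined with logical OR.
    # Work grows linearly with len(values), so this suits short lists.
    start = torch.zeros_like(m, dtype=torch.bool)  # result for an empty list
    return functools.reduce(operator.or_, (m == v for v in values), start)


some_values = [3, 5, 8]
m = torch.tensor([[3, 1], [8, 5]])
mask = isin_or_chain(m, some_values)
```

Each `m == v` allocates only a tensor the size of `m`, so memory stays small, but there is one kernel launch per value.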
I ended up doing it this way:
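import torch

r = 2  # number of rows in the matrix
c = 2  # number of columns in the matrix
n = 3  # number of elements in the collection you're comparing to

m = torch.randint(0, 10, (r, c))
some_values = torch.randint(0, 10, (n,))  # size must be a tuple, hence (n,)

# Bring both to shape (r, c, n) so every element of m meets every value
new_m = m.unsqueeze(2).expand(r, c, n)
new_val = some_values.repeat(r, c).view(r, c, n)

# 1.0 where m[i, j] == some_values[k], else 0.0
mask = torch.where(new_m == new_val, torch.ones(r, c, n), torch.zeros(r, c, n))
# Number of matches per element (can exceed 1 if some_values has duplicates);
# output > 0 gives the boolean "m[i, j] in some_values" mask
output = mask.sum(2)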
I’m not sure how efficient it is compared to your method, though.
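The same idea can be written with plain broadcasting, which skips the explicit `expand`/`repeat` copies and the float mask, and returns a boolean result directly via `.any`. A sketch under the assumption that `some_values` fits comfortably in memory as an extra trailing dimension; `isin_broadcast` is a hypothetical name. Newer PyTorch releases (1.10 and later) also provide `torch.isin`, which computes the same mask in one call, so the sketch checks against it:

```python
import torch


def isin_broadcast(m, values):
    # m.unsqueeze(-1) has shape (*m.shape, 1); comparing it with a 1-D `values`
    # broadcasts to (*m.shape, n) without materialising expanded copies first
    values = torch.as_tensor(values, dtype=m.dtype, device=m.device)
    # .any over the last dim answers "is m[i, j] equal to some value?"
    return (m.unsqueeze(-1) == values).any(dim=-1)


m = torch.randint(0, 10, (2, 2))
some_values = torch.tensor([3, 5, 8])

mask = isin_broadcast(m, some_values)
same = torch.equal(mask, torch.isin(m, some_values))  # True: both give the same boolean mask
```

The broadcast comparison still builds an intermediate boolean tensor of size r * c * n, so for long value lists it is worth timing it against `torch.isin` on your actual sizes rather than assuming either is faster.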