ab-10
(Armins Stepanjans)
Is there an efficient way to apply Python’s `in` (`__contains__`) elementwise to a `torch.tensor`?
This is the behaviour I’m looking for:
```python
>>> some_values = [3, 5, 8]
>>> m = torch.randint(0, 10, (2, 2))
>>> m
[[0, 9],
 [8, 5]]
>>> m in some_values
[[False, False],
 [True, True]]
```
Thanks in advance!
One way to do it:

```python
(m == some_values[0]) | (m == some_values[1]) | (m == some_values[2])
```

I’m not sure how efficient it is.
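The OR chain above hard-codes one comparison per value. A minimal sketch that generalizes it to a list of any length, assuming `some_values` is a plain Python list of scalars (`isin_or_chain` is a hypothetical helper name, not a PyTorch function):

```python
import functools
import operator

import torch


def isin_or_chain(m: torch.Tensor, values: list) -> torch.Tensor:
    """Elementwise `m in values`, built by OR-ing one equality mask per value."""
    # Start from an all-False mask so an empty `values` list yields all False
    start = torch.zeros_like(m, dtype=torch.bool)
    # (m == v) is a bool tensor shaped like m; | combines the masks elementwise
    return functools.reduce(operator.or_, (m == v for v in values), start)


m = torch.tensor([[0, 9], [8, 5]])
print(isin_or_chain(m, [3, 5, 8]).tolist())  # [[False, False], [True, True]]
```

Each value costs one full pass over `m`, so the work grows with `len(values) * m.numel()`; that is usually fine when the list of values is short.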
ab-10
(Armins Stepanjans)
I ended up doing it this way:
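```python
import torch

r, c, n = 2, 2, 3  # example sizes: rows, columns, number of values to compare against

m = torch.randint(0, 10, (r, c))
some_values = torch.randint(0, 10, (n,))  # size must be a tuple: (n,), not (n)

# Line every element of m up against all n values along a new last dimension
new_m = m.unsqueeze(2).expand(r, c, n)            # shape (r, c, n)
new_val = some_values.repeat(r, c).view(r, c, n)  # shape (r, c, n)

# 1.0 where an element of m equals a value, 0.0 otherwise
mask = torch.where(new_m == new_val, torch.ones(r, c, n), torch.zeros(r, c, n))

# Matches per element; a nonzero count means m[i, j] is in some_values
output = mask.sum(2)
```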
I’m not sure how efficient it is compared to your method, though.
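The explicit `expand`/`repeat` step isn't strictly needed: with broadcasting, `m.unsqueeze(-1)` of shape `(r, c, 1)` can be compared directly against `some_values` of shape `(n,)`, and `.any(-1)` returns a boolean mask rather than a match count. PyTorch 1.10 and later also provide `torch.isin`, which does this in a single call. A sketch of both, assuming a PyTorch version that has `torch.isin` (`isin_broadcast` is a hypothetical helper name):

```python
import torch


def isin_broadcast(m: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Elementwise `m in values` via broadcasting; works for any shape of m."""
    # m.unsqueeze(-1) has shape (*m.shape, 1); comparing it with values of
    # shape (n,) broadcasts to (*m.shape, n) with no expand/repeat needed
    return (m.unsqueeze(-1) == values).any(dim=-1)


m = torch.tensor([[0, 9], [8, 5]])
some_values = torch.tensor([3, 5, 8])

print(isin_broadcast(m, some_values).tolist())  # [[False, False], [True, True]]

# Built-in equivalent (PyTorch >= 1.10)
print(torch.isin(m, some_values).tolist())      # [[False, False], [True, True]]
```

The broadcast comparison still materializes an intermediate boolean tensor with `m.numel() * n` entries, so for long value lists it may be worth benchmarking it against `torch.isin`.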