I have the following tensors:
```python
import torch
# 2 x 5 x 3
a = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
b = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
```
I’d like to obtain the following:
```python
# 2 x 5 x 5
c = torch.tensor([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]],
                  [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]]])
```
Here, the values of `c` are obtained with the following function, so that `c[n, i, j] = test(a[n, i], b[n, j])`:
```python
# Compare each row of a with each row of b in the same batch.
# If they are equal, set the value at those indices in c to 1; otherwise set it to 0.
# For instance, a's first row [1, 3, 2] is compared with b's first row [19, 21, 20].
# Since they are not equal, c's first value c[0, 0, 0] is 0.
def test(x, y):
    if torch.all(x == y):
        return 1
    else:
        return 0
```
How do I do this without looping over `a` and `b`?
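A minimal sketch of a loop-free version using standard PyTorch broadcasting, assuming the same tensors (again with `a[1, 2]` taken as `[43, 45, 44]`): insert singleton dimensions so the `==` comparison broadcasts over every `(i, j)` pair within each batch, then require all components of a row to match by reducing over the last dimension. This is one possible approach, not the only one.

```python
import torch

# Same tensors as in the question; a[1, 2] assumed to be [43, 45, 44]
a = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
b = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

# a[:, :, None, :] has shape (2, 5, 1, 3) and b[:, None, :, :] has shape (2, 1, 5, 3).
# Broadcasting the comparison yields a (2, 5, 5, 3) boolean tensor where
# eq[n, i, j, k] is True exactly when a[n, i, k] == b[n, j, k].
eq = a[:, :, None, :] == b[:, None, :, :]

# Rows match only if all 3 components are equal: reduce over the last dim,
# then convert the boolean result to 0/1 integers (int64).
c = eq.all(dim=-1).long()
print(c)
```

The intermediate `eq` tensor has shape `(N, I, J, D)`, here `(2, 5, 5, 3)`, so memory grows with the product of the two row counts; for very large inputs, processing the batch dimension in chunks keeps that bounded.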