Applying a function to a batched tensor

I have the following tensors:

import torch

# shape of a and b: 2 x 5 x 3
a = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]], [[31, 33, 32], [37, 39, 38], [43, 45, 46], [49, 51, 50], [55, 57, 56]]])
b = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]], [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

I’d like to obtain the following:

# 2 x 5 x 5
c = torch.tensor([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]]])

where the values of c are obtained using the following function:

# Compare each row (3-element vector) of a with each row of b in the same batch element.
# If they are equal, set the value of c at those indices to 1; otherwise set it to 0.
# For instance, a's first row [1, 3, 2] is compared with b's first row [19, 21, 20]. Since they are not equal, c's first value is set to 0.
def test(x, y):
    if torch.all(x==y):
        return 1
    else:
        return 0

How do I do this without looping over a and b?

Hi Seven!

I don’t understand your use case. Among other things, the second batch
element of your tensor c (c[1]) seems to indicate five matches between
the rows of a[1] and b[1], but I only see four: a[1][2] is [43, 45, 46],
whereas b[1][3] is [43, 45, 44].

Could you post a complete, runnable script that illustrates your use case
by computing your desired output?

This approach – which may or may not be applicable to your specific
use case – illustrates a loop-free way to compare the rows of a and b:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> a = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]], [[31, 33, 32], [37, 39, 38], [43, 45, 46], [49, 51, 50], [55, 57, 56]]])
>>> b = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]], [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
>>> c = torch.tensor([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]], [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]]])
>>> cB = (a.unsqueeze (1) == b.unsqueeze (2)).prod (dim = 3)
>>> torch.equal (c[0], cB[0])
True
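A minimal sketch of the same broadcasting idea, with the `unsqueeze` positions arranged so that entry [n, i, j] compares row i of a with row j of b, which is the indexing the question's definition of c uses. It reduces with `.all(dim=-1)` instead of `.prod(dim=3)`, which keeps a bool mask, and `.long()` turns that mask into 0/1. With the `unsqueeze` positions swapped, as in the transcript, the result is the transpose of this one; the two agree on batch element 0 only because its pattern of matches happens to be symmetric. The names `match`, `c_vec` and `swapped` are illustrative.

```python
import torch

# Data from the question: a and b are 2 x 5 x 3, c is 2 x 5 x 5
a = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 46], [49, 51, 50], [55, 57, 56]]])
b = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
c = torch.tensor([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]],
                  [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]]])

# a.unsqueeze(2): 2 x 5 x 1 x 3  -> row i of a stays on dim 1
# b.unsqueeze(1): 2 x 1 x 5 x 3  -> row j of b moves to dim 2
# The comparison broadcasts to 2 x 5 x 5 x 3; .all(dim=-1) is True only
# where all three entries of a[n, i] equal those of b[n, j].
match = (a.unsqueeze(2) == b.unsqueeze(1)).all(dim=-1)  # bool, 2 x 5 x 5
c_vec = match.long()  # c_vec[n, i, j] == test(a[n, i], b[n, j])

# Swapped unsqueeze positions: entry [n, i, j] compares b[n, i] with a[n, j]
swapped = (a.unsqueeze(1) == b.unsqueeze(2)).all(dim=-1)
print(torch.equal(swapped, match.transpose(1, 2)))  # True
print(torch.equal(swapped[0], match[0]))  # True: batch 0 matches are symmetric
print(torch.equal(swapped[1], match[1]))  # False
```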

Best.

K. Frank

Thanks for your response, K. Frank. Here is what I’m after (see R in the image):

import torch

# shape of o and p: 2 x 5 x 3
o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]], [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]], [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

# this is O' in the image; shape: 2 x 5
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])

# this is P' in the image; shape: 2 x 5
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

# this is R, the output I need; shape: 2 x 5 x 5
r = torch.tensor([[[0, 0, 0, 6.1, 0], [0, 24.2, 0, 0, 0], [0, 0, 42.3, 0, 0], [60.4, 0, 0, 0, 0], [0, 0, 0, 0, 78.5]], [[0, 96.6, 0, 0, 0], [0, 0, 114.7, 0, 0], [0, 0, 0, 132.8, 0], [0, 0, 0, 0, 150.9], [168.11, 0, 0, 0, 0]]])

Let me know if you need more clarification.
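The image is not available here, so what follows is only an inference from the posted numbers. Every nonzero entry of r sits where a row of o equals a row of p (the same pattern as c), and its value looks like the sum of that row of o plus the matching entry of o_prime: for example, 1 + 3 + 2 + 0.1 = 6.1 and 55 + 57 + 56 + 0.11 = 168.11. Under that reading, p_prime does not contribute. Because matched rows are identical, the row sum could just as well be taken from p. A minimal sketch of this assumed rule, reusing the broadcast comparison (the names `match`, `value` and `r_hat` are illustrative):

```python
import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
r = torch.tensor([[[0, 0, 0, 6.1, 0], [0, 24.2, 0, 0, 0], [0, 0, 42.3, 0, 0], [60.4, 0, 0, 0, 0], [0, 0, 0, 0, 78.5]],
                  [[0, 96.6, 0, 0, 0], [0, 0, 114.7, 0, 0], [0, 0, 0, 132.8, 0], [0, 0, 0, 0, 150.9], [168.11, 0, 0, 0, 0]]])

# ASSUMED rule, inferred from the numbers above (not from the image):
#   r[n, i, j] = o[n, i].sum() + o_prime[n, i]   if o[n, i] equals p[n, j]
#              = 0                                otherwise
match = (o.unsqueeze(2) == p.unsqueeze(1)).all(dim=-1)  # bool, 2 x 5 x 5
value = o.sum(dim=-1) + o_prime                         # float, 2 x 5
# Broadcast value over j and zero out the non-matching entries
r_hat = match.to(value.dtype) * value.unsqueeze(-1)     # float, 2 x 5 x 5

print(torch.allclose(r_hat, r))  # True for the posted data
```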