Equivalent of tensorflow.equal()

Given a tensor A with shape (5x1) and a tensor B with shape (1x5), in PyTorch:
import torch
torch.eq(A, B) produces a tensor C with shape (5x1), which is not what I expect.

Is there an equivalent of “tensorflow.equal(A, B)” in PyTorch that produces a tensor with shape (5x5)?
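To make the expected (5x5) result concrete: tensorflow.equal broadcasts its inputs, so for A of shape (5x1) and B of shape (1x5) the entry C[i, j] is A[i, 0] == B[0, j]. The sketch below spells that out with explicit loops as a reference definition, not an efficient implementation; the helper name pairwise_equal_reference and the sample values are illustrative assumptions.

```python
import torch

def pairwise_equal_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: A has shape (n, 1), B has shape (1, m).
    # The result has shape (n, m), one entry per (row of A, column of B) pair.
    n, m = A.shape[0], B.shape[1]
    C = torch.empty(n, m, dtype=torch.bool)
    for i in range(n):
        for j in range(m):
            # Compare one element of A's column with one element of B's row
            C[i, j] = A[i, 0] == B[0, j]
    return C

# Illustrative inputs (assumed values)
A = torch.tensor([[1], [2], [3], [4], [5]])  # shape (5, 1)
B = torch.tensor([[5, 4, 3, 2, 1]])          # shape (1, 5)

C = pairwise_equal_reference(A, B)
print(C.shape)  # torch.Size([5, 5])
```

With these inputs the True entries lie on the anti-diagonal, since A[i, 0] equals B[0, 4 - i].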

Thanks,
metas

A.expand(5, 5).eq(B.expand(5, 5))
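A minimal sketch of the expand-based answer, generalized from the hard-coded 5 to any (n, 1) and (1, m) pair. The helper name pairwise_equal and the sample tensors are hypothetical. Tensor.expand repeats size-1 dimensions as a view without copying data, so both operands reach shape (n, m) before eq compares them elementwise.

```python
import torch

def pairwise_equal(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: A has shape (n, 1), B has shape (1, m).
    # expand() returns views with the size-1 dims repeated (no data copied),
    # so both operands become (n, m) before the elementwise comparison.
    n, m = A.shape[0], B.shape[1]
    return A.expand(n, m).eq(B.expand(n, m))

# Illustrative inputs (assumed values)
A = torch.tensor([[1], [2], [3], [4], [5]])  # shape (5, 1)
B = torch.tensor([[5, 4, 3, 2, 1]])          # shape (1, 5)

C = pairwise_equal(A, B)
print(C.shape)  # torch.Size([5, 5])

# PyTorch versions that support NumPy-style broadcasting expand the size-1
# dimensions automatically, so torch.eq(A, B) gives the same (5, 5) result.
print(torch.equal(C, torch.eq(A, B)))  # True
```

On a PyTorch version with broadcasting, the explicit expand calls are optional; they remain useful when you want the target shape to be stated explicitly.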
