Pairwise equality + reduction

I am trying to compute the pairwise equality of all elements of two 1D tensors, x and y. Then I want to know how many times each element of y occurs in x. Right now I am using the following:

x = x.view(-1,1)
y = y.view(1,-1)
(x==y).sum(dim=0)
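For reference, here is a self-contained version of the broadcast approach above. The function name `count_broadcast` and the sample tensors are illustrative only. `counts[j]` is the number of positions `i` where `x[i] == y[j]`.

```python
import torch

def count_broadcast(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Builds the full (len(x), len(y)) boolean matrix, so memory grows
    # with x.numel() * y.numel().
    eq = x.view(-1, 1) == y.view(1, -1)
    # Reduce over the x dimension: one count per element of y.
    return eq.sum(dim=0)

x = torch.tensor([1, 2, 2, 3, 3, 3])
y = torch.tensor([3, 2, 5])
print(count_broadcast(x, y))  # tensor([3, 2, 0])
```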

This works fine, but the intermediate x==y matrix has x.numel() * y.numel() elements, so it will not fit into GPU memory for large x and y. Is there a way to avoid creating the x==y matrix, given that I only care about the final reduced sum?

Any help would be appreciated.


You could trade memory for compute by using a loop: compare chunks of x against y (including the sum) instead of broadcasting over all elements at once, and accumulate the partial sums. I don't know if there is another, better approach "between" these two extremes.
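A minimal sketch of the chunked loop, assuming only x is large enough to need chunking. The function name `count_chunked` and the `chunk_size` parameter are hypothetical. At any time only a `(chunk_size, len(y))` boolean matrix exists, so peak memory for the comparison is roughly `chunk_size * y.numel()` bytes instead of `x.numel() * y.numel()`.

```python
import torch

def count_chunked(x: torch.Tensor, y: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    # Hypothetical helper: same result as (x.view(-1, 1) == y.view(1, -1)).sum(dim=0),
    # but only a (chunk_size, len(y)) comparison matrix is alive at a time.
    x = x.reshape(-1)
    y = y.reshape(-1)
    counts = torch.zeros(y.numel(), dtype=torch.long, device=y.device)
    for start in range(0, x.numel(), chunk_size):
        x_chunk = x[start:start + chunk_size]
        # Broadcast only this chunk against y, then reduce over the chunk dimension.
        counts += (x_chunk.unsqueeze(1) == y.unsqueeze(0)).sum(dim=0)
    return counts

x = torch.randint(0, 10, (10_000,))
y = torch.randint(0, 10, (500,))
print(count_chunked(x, y, chunk_size=777)[:5])
```

Smaller chunks lower memory use but launch more kernels. If y is also very large, the same idea can be applied to y by slicing it and writing each slice's result into the matching part of `counts`. As a different technique, not the loop suggested here: if the values are sortable, sorting x and taking `torch.searchsorted(x_sorted, y, right=True) - torch.searchsorted(x_sorted, y)` gives the same counts without any pairwise comparison matrix.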