Use torch.eq() only for some values

Is there a way to use torch.eq() or a similar function to compute element-wise equality, but only for some values? Let's say I need to know at how many positions both tensors contain a 1, but I don't care whether other numbers match.

Any idea how to do this?

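I guess you can do: a.eq(b) & a.eq(1), which is the operator form of a.eq(b).__and__(a.eq(1)). Calling .sum() on the resulting boolean mask then gives the number of positions where both tensors hold a 1.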
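A minimal sketch of that approach, assuming both tensors have the same shape and integer values (the tensor contents below are made up for illustration). The second part is an assumed extension for checking several values at once with torch.isin, which needs PyTorch 1.10 or newer.

```python
import torch

# Hypothetical example tensors of the same shape.
a = torch.tensor([1, 0, 1, 2, 1, 3])
b = torch.tensor([1, 1, 1, 2, 0, 3])

# Mask of positions where a and b agree AND the value is 1.
# `&` on bool tensors is the same as calling .__and__().
ones_match = a.eq(b) & a.eq(1)

# Summing a bool tensor counts its True entries.
num_matching_ones = int(ones_match.sum())

# Assumed extension: count agreements for any value in a chosen set.
values_of_interest = torch.tensor([1, 2])
set_match = a.eq(b) & torch.isin(a, values_of_interest)
num_matching_in_set = int(set_match.sum())

print(num_matching_ones, num_matching_in_set)
```

Only one side needs the value check, because wherever a.eq(b) is True the two tensors hold the same number.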