If I have two tensors, I can check them for equality using torch.all(a == b), torch.all(torch.eq(a, b)), or something similar. The tensor == operator returns a tensor of bools (one per element), so that result has to be reduced to a single bool afterward.
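A minimal sketch of the tensor-to-tensor case described above. It also uses torch.equal, which is not mentioned in the question. torch.equal returns a plain Python bool and requires the two tensors to have the same size. The reduction approach with ==, by contrast, broadcasts, so tensors with different shapes can still compare as "equal".

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0])

# Elementwise comparison: a bool tensor with one entry per element
mask = a == b

# Reduce the mask to a single Python bool
all_equal = torch.all(mask).item()
all_equal_eq = torch.eq(a, b).all().item()

# torch.equal returns a Python bool directly and also requires identical sizes
strict_equal = torch.equal(a, b)

# Caveat: == broadcasts, so differently shaped tensors can still "match"
c = torch.ones(3)
d = torch.ones(1)
broadcast_match = torch.all(c == d).item()  # True: d is broadcast to shape (3,)
size_aware = torch.equal(c, d)              # False: the sizes differ
```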
If I have a data structure whose values are tensors (like a list or a dict), how do I compare two such structures for equality? The containers themselves are not aware of torch, so when they recursively compare their members, they don't know that tensors need special handling. I'd really like to avoid subclassing the built-in types if possible.
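A rough sketch of one way to handle the nested case without subclassing: a small recursive function that walks both structures and uses torch.equal at the tensor leaves. The name tree_equal is hypothetical, not a PyTorch API. The sketch only handles dicts, lists, and tuples, and it assumes corresponding tensors share a dtype and device. The snippet also shows why plain == fails. The built-in containers call bool() on each member comparison, and bool() raises for a tensor with more than one element.

```python
import torch

def tree_equal(x, y):
    """Recursively compare nested dicts, lists and tuples whose leaves may be tensors.

    Hypothetical helper for illustration; it is not part of the PyTorch API.
    """
    # Tensor leaves: require both sides to be tensors, then compare shape and values
    if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
        return (
            isinstance(x, torch.Tensor)
            and isinstance(y, torch.Tensor)
            and torch.equal(x, y)
        )
    # Dicts: same key set, then compare values key by key
    if isinstance(x, dict) and isinstance(y, dict):
        return x.keys() == y.keys() and all(tree_equal(x[k], y[k]) for k in x)
    # Lists/tuples: same container type and length, then compare elementwise
    if isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
        return (
            type(x) is type(y)
            and len(x) == len(y)
            and all(tree_equal(i, j) for i, j in zip(x, y))
        )
    # Any other leaf: fall back to ordinary Python equality
    return x == y


p = {"weight": torch.tensor([1.0, 2.0]), "extra": [torch.zeros(2), 3]}
q = {"weight": torch.tensor([1.0, 2.0]), "extra": [torch.zeros(2), 3]}

# The built-in dict comparison calls bool() on the elementwise tensor result,
# which raises RuntimeError for tensors with more than one element
try:
    _ = p == q
    builtin_failed = False
except RuntimeError:
    builtin_failed = True

same = tree_equal(p, q)
```

Dispatching on type in a free function keeps the containers as ordinary dicts and lists. Supporting other container types means adding another branch.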
As an aside, does anyone know why == was designed to return an elementwise result, instead of returning a single bool and providing a separate function for elementwise equality?