Recursive Equality

If I have two tensors, I can check for equality between the tensors using torch.all(a == b) or torch.all(torch.eq(a, b)) or something like that. The tensor == operator returns a tensor of bools, so you need to aggregate across the bools afterward.
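A small sketch of the difference between the elementwise result and the aggregated one. The tensors `a` and `b` are arbitrary values chosen for illustration. `torch.equal` is shown as another option that returns a plain Python bool directly.

```python
import torch

# Arbitrary example tensors that differ in the last element
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 4.0])

elementwise = a == b        # tensor of bools, one per element
same = torch.all(a == b)    # 0-dim bool tensor aggregating the elementwise result
same_bool = bool(same)      # convert the 0-dim tensor to a plain Python bool

# torch.equal returns a Python bool: True only if shapes and all values match
exact = torch.equal(a, b)

print(elementwise)
print(same_bool, exact)
```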

If I have a data structure whose values are tensors (like a list or a dict), how do I compare two such data structures for equality? The structures themselves are not aware of torch, so when they recursively check equality of their members, they don’t know that they need to do something special when comparing tensors. I’d really like to avoid subclassing the built-in data types if possible.

As an aside, does anyone know why it was decided to implement == this way, instead of having == return a bool and providing a separate function for elementwise equality?

You could use e.g. map to compare the paired elements of both containers.
Also, note that == is usually not recommended for floating-point numbers, so you should use torch.allclose, if applicable:

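```python
import torch

lista = [torch.randn(1) for _ in range(5)]
listb = [a + 1e-6 for a in lista]

# With the default atol=1e-8, an absolute offset of 1e-6 fails the check for values
# close to zero, so a looser atol is passed explicitly here
all(map(lambda x, y: torch.allclose(x, y, atol=1e-5), lista, listb))
# True
```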
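The map approach works for flat lists of the same length. For arbitrarily nested containers, which is what the question asks about, one option is a small recursive helper, with no subclassing of built-in types needed. This is a minimal sketch under assumptions: `recursive_allclose` is a hypothetical name, and it only understands dicts, lists, tuples, and tensors, falling back to plain `==` for any other leaf (ints, strings, None, ...).

```python
import torch


def recursive_allclose(x, y, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: compare nested dicts/lists/tuples whose leaves may be tensors."""
    if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor):
        # Both sides must be tensors of the same shape before allclose is applied
        return (
            isinstance(x, torch.Tensor)
            and isinstance(y, torch.Tensor)
            and x.shape == y.shape
            and torch.allclose(x, y, rtol=rtol, atol=atol)
        )
    if isinstance(x, dict) and isinstance(y, dict):
        # Same keys, then compare the values key by key
        return x.keys() == y.keys() and all(
            recursive_allclose(x[k], y[k], rtol, atol) for k in x
        )
    if isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
        # Same container type and length, then compare position by position
        return (
            type(x) is type(y)
            and len(x) == len(y)
            and all(recursive_allclose(a, b, rtol, atol) for a, b in zip(x, y))
        )
    # Any other leaf falls back to ordinary equality
    return x == y


state_a = {"weights": [torch.ones(2, 2), torch.zeros(3)], "step": 10}
state_b = {"weights": [torch.ones(2, 2) + 1e-9, torch.zeros(3)], "step": 10}
print(recursive_allclose(state_a, state_b))
```

The shape check runs before `torch.allclose` so that tensors of different (but broadcastable) shapes are reported as unequal instead of being broadcast against each other.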

As for the aside: most likely to stick to the NumPy API, where == on arrays also returns an elementwise result.
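A quick sketch of the parallel with NumPy. The arrays are arbitrary illustration values; in both libraries, == compares element by element and returns an array/tensor of bools rather than a single bool.

```python
import numpy as np
import torch

# The same values in NumPy and PyTorch
a_np = np.array([1, 2, 3])
b_np = np.array([1, 5, 3])
a_t = torch.tensor([1, 2, 3])
b_t = torch.tensor([1, 5, 3])

print(a_np == b_np)  # [ True False  True]
print(a_t == b_t)    # tensor([ True, False,  True])
```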