A question about the accuracy of einsum calculations

import torch

# 1
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
tensor(False)

# 2
a = torch.randint(low=1, high=10, size=(4, 5, 6, 7))
b = torch.randint(high=4, size=(4, 6, 7, 9, 6, 7))
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
tensor(True)

# 3
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
c = torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == c)
tensor(True)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
tensor(False)
torch.all(torch.isclose(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)), torch.einsum('bdhw,bhwnxy->bdxy', a, b)))
tensor(True)

I want to use torch.einsum('bdhw,bhwnxy->bdxy', a, b), but I get a strange result.
Its output differs from torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) when the tensors are floats. I assume this is due to floating-point rounding error.
However, for the same float tensors, the result of torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2) matches torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) exactly. I wonder why that comparison is not affected by floating-point rounding.
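The bitwise mismatch between the two einsum forms is consistent with floating-point addition not being associative: mathematically equivalent expressions that accumulate in a different order can differ in the last bits. A minimal sketch (values chosen only for illustration) showing grouping alone changing the result:

```python
import torch

# Floating-point addition is not associative: grouping the same
# operands differently can change the low-order bits of the result.
a = torch.tensor(0.1, dtype=torch.double)
b = torch.tensor(0.2, dtype=torch.double)
c = torch.tensor(0.3, dtype=torch.double)

left = (a + b) + c   # 0.6000000000000001
right = a + (b + c)  # 0.6

print(left == right)               # tensor(False)
print(torch.isclose(left, right))  # tensor(True)
```

The same effect, scaled up to the thousands of additions inside a tensor contraction, is why == can fail even when both expressions are "correct".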

When comparing float/double tensors, use torch.allclose() instead of ==.

Another related note about float vs. double:

By default, the PyTorch data type is torch.float (32-bit). Rounding error for float operations is on the order of 1e-6. If you set the default data type to double with torch.set_default_dtype(torch.double), the error drops to roughly 1e-15 or below.
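To make the advice concrete, here is a sketch of the comparison with torch.allclose on the shapes from the question. The elementwise criterion allclose uses is |x - y| <= atol + rtol * |y|, with defaults rtol=1e-5 and atol=1e-8, and both knobs can be set explicitly:

```python
import torch

torch.manual_seed(0)
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)

out1 = torch.einsum('bdhw,bhwnxy->bdxy', a, b)
out2 = torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))

# Elementwise check: |out1 - out2| <= atol + rtol * |out2|
print(torch.allclose(out1, out2))                        # default rtol=1e-5, atol=1e-8
print(torch.allclose(out1, out2, rtol=1e-4, atol=1e-6))  # explicit tolerances
```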

Thanks for your reply.
torch.allclose() does work, but I still wonder why the two expressions below give exactly the same result, with no floating-point discrepancy at all.

c = torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == c)
tensor(True)
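One plausible explanation (an assumption about PyTorch internals, not verified here): torch.einsum lowers the pairwise contraction 'bdhw,bhwnxy->bdnxy' to the same batched matmul that the explicit torch.bmm call performs, so both paths execute an identical sequence of floating-point operations and therefore agree bit for bit, while the fused 'bdhw,bhwnxy->bdxy' form may reduce over n in a different order, which changes the rounding. A small sketch of the underlying principle: repeating the same operation sequence reproduces results exactly, whereas a mathematically equivalent computation with a different accumulation order generally does not:

```python
import torch

torch.manual_seed(0)
a = torch.rand(64, 64)
b = torch.rand(64, 64)

# Running the exact same computation twice on the same inputs gives
# bitwise-identical output (same kernel, same rounding at every step).
m1 = a @ b
m2 = a @ b
print(torch.equal(m1, m2))  # True

# Mathematically, a @ rowsum(b) == rowsum(a @ b), but the two sides
# accumulate in a different order, so they usually differ in the last bits.
sum_then_mm = a @ b.sum(dim=1, keepdim=True)
mm_then_sum = (a @ b).sum(dim=1, keepdim=True)
print(torch.equal(sum_then_mm, mm_then_sum))     # often False
print(torch.allclose(sum_then_mm, mm_then_sum))  # close, though
```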