# A question about the accuracy of einsum calculation

``````python
import torch

# 1: float tensors, one-step einsum vs. two-step einsum
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(False)

# 2: integer tensors, same comparison
a = torch.randint(low=1, high=10, size=(4, 5, 6, 7))
b = torch.randint(high=4, size=(4, 6, 7, 9, 6, 7))
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(True)

# 3: float tensors, two-step einsum vs. an equivalent bmm-based computation
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
c = torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == c)
# tensor(True)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(False)
torch.all(torch.isclose(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)), torch.einsum('bdhw,bhwnxy->bdxy', a, b)))
# tensor(True)
``````

I want to use `torch.einsum('bdhw,bhwnxy->bdxy', a, b)`, but I get a strange result.
When the tensors are floats, its result differs from `torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))`. I think this is due to floating-point rounding error...
However, for the same float tensors, the result of `torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)` is exactly equal to `torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))`. I wonder why this one is not affected by floating-point error.

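When comparing `float`/`double` tensors, use `torch.allclose()` instead of `==`. Floating-point addition is not associative, so the result of a reduction depends on the order in which the terms are summed. The single `einsum` that reduces over `h`, `w` and `n` at once probably sums in a different order than the two-step version. The `bmm`-based code, on the other hand, most likely performs the same sequence of operations as the two-step `einsum` (PyTorch's `einsum` typically computes the contraction with `bmm`), which would explain why those two agree bit for bit.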
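A minimal sketch of both points, reusing the shapes from the question: a tiny float32 example where changing the order of additions changes the result, followed by a tolerance-based comparison of the one-step and two-step `einsum` calls. The variable names are illustrative only; the exact size of the discrepancy will vary from run to run.

```python
import torch

# Non-associativity in float32: 1e8 + 1 rounds back to 1e8,
# because float32 values near 1e8 are spaced 8 apart.
big = torch.tensor(1e8, dtype=torch.float32)
one = torch.tensor(1.0, dtype=torch.float32)
left = (big + one) - big   # the 1 is lost in the first addition -> 0.0
right = (big - big) + one  # the large terms cancel first -> 1.0
print(left.item(), right.item())

# The same effect, at a much smaller scale, in the question's einsum calls.
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
two_step = torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
one_step = torch.einsum('bdhw,bhwnxy->bdxy', a, b)

print(torch.equal(two_step, one_step))     # exact equality: may be False
print(torch.allclose(two_step, one_step))  # defaults: rtol=1e-05, atol=1e-08
print((two_step - one_step).abs().max())   # size of the rounding discrepancy
```

With entries around 90 in magnitude, the default `rtol` already allows differences of roughly 1e-3, so small summation-order differences pass `allclose` while still failing `==`.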

Another related note about float vs. double:

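By default, PyTorch's floating-point dtype is `torch.float` (`torch.float32`), which carries about 7 significant decimal digits (machine epsilon ≈ 1.2 × 10^-7), so differences on the order of 10^-6 relative to the magnitude of the result are to be expected. If you set the default dtype to double with `torch.set_default_dtype(torch.double)`, machine epsilon drops to ≈ 2.2 × 10^-16 and these differences shrink to roughly 10^-15 or less.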
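A rough way to see the float vs. double gap on the question's own computation is to run the same comparison at both precisions. This sketch passes `dtype=` explicitly instead of calling `torch.set_default_dtype`, to avoid changing global state; `max_rel_diff` is a hypothetical helper, not a PyTorch API.

```python
import torch

def max_rel_diff(a, b):
    """Largest relative difference between the one-step and two-step einsum."""
    two_step = torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
    one_step = torch.einsum('bdhw,bhwnxy->bdxy', a, b)
    return ((two_step - one_step).abs() / one_step.abs()).max().item()

# Generate data once in float64 and reuse it (cast down) for float32.
a64 = torch.rand(4, 5, 6, 7, dtype=torch.float64)
b64 = torch.rand(4, 6, 7, 9, 6, 7, dtype=torch.float64)

err32 = max_rel_diff(a64.float(), b64.float())
err64 = max_rel_diff(a64, b64)

print(f"float32: {err32:.2e} (eps = {torch.finfo(torch.float32).eps:.2e})")
print(f"float64: {err64:.2e} (eps = {torch.finfo(torch.float64).eps:.2e})")
```

The observed differences should sit within a small multiple of each dtype's machine epsilon; either value can also be exactly zero for a particular draw.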
