```python
import torch

# 1: float tensors -- exact equality between the two einsum forms fails
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(False)

# 2: integer tensors -- exact equality holds
a = torch.randint(low=1, high=10, size=(4, 5, 6, 7))
b = torch.randint(high=4, size=(4, 6, 7, 9, 6, 7))
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(True)

# 3: float tensors -- the bmm reformulation matches the nested einsum exactly
a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
c = torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == c)
# tensor(True)
torch.all(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)) == torch.einsum('bdhw,bhwnxy->bdxy', a, b))
# tensor(False)
torch.all(torch.isclose(torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b)), torch.einsum('bdhw,bhwnxy->bdxy', a, b)))
# tensor(True)
```

I want to use `torch.einsum('bdhw,bhwnxy->bdxy', a, b)`, but I get a strange result: it differs from `torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))` when the tensors are floats (case 1), even though the two agree exactly for integer tensors (case 2). I suspect this is due to floating-point rounding error…
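A minimal sketch of the rounding effect I have in mind, using plain Python floats (IEEE 754 doubles): addition is not associative, so two mathematically identical reductions that group their sums differently can produce results that differ in the last bits while remaining extremely close.

```python
# Floating-point addition is not associative: each intermediate sum is
# rounded, so the grouping of operations changes the final bit pattern.
left = (0.1 + 0.2) + 0.3   # rounds after 0.1 + 0.2
right = 0.1 + (0.2 + 0.3)  # rounds after 0.2 + 0.3
print(left == right)               # False: exact comparison fails
print(abs(left - right) < 1e-12)   # True: the results are still very close
```

This is the same reason exact `==` comparison of the two float einsum results fails while `torch.isclose` succeeds, assuming the single-shot einsum and the nested einsum reduce over `n` in different orders internally.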

However, for the same float tensors, the result of `torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)` is *exactly* equal to `torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))` (case 3). I wonder **why this comparison is not affected by floating-point rounding**.