# Example 1: float tensors.
# Contract h,w first, then sum over n — versus the fused single contraction.
# Floating-point addition is not associative, so the two paths accumulate
# partial sums in different orders and bitwise equality generally fails.
import torch

a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
print(torch.all(
    torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
    == torch.einsum('bdhw,bhwnxy->bdxy', a, b)))
# expected output: tensor(False)   (results differ only by rounding error)
# Example 2: integer tensors.
# Integer addition IS associative (no rounding), so the two contraction
# orders produce bitwise-identical results.
import torch

a = torch.randint(low=1, high=10, size=(4, 5, 6, 7))
b = torch.randint(high=4, size=(4, 6, 7, 9, 6, 7))
print(torch.all(
    torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
    == torch.einsum('bdhw,bhwnxy->bdxy', a, b)))
# expected output: tensor(True)
# Example 3: float tensors, comparing against an explicit bmm formulation.
# c computes the same contraction as the two-step einsum: a batched matmul
# over the flattened (h*w) axis, then a sum over n.  On this run it matched
# the two-step einsum bitwise (same reduction order under the hood), while
# the fused one-step einsum differed by rounding — torch.isclose confirms
# all three agree up to floating-point tolerance.
import torch

a = torch.rand(4, 5, 6, 7)
b = torch.rand(4, 6, 7, 9, 6, 7)
c = torch.bmm(a.reshape(4, 5, 6 * 7), b.reshape(4, 6 * 7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)

two_step = torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
one_step = torch.einsum('bdhw,bhwnxy->bdxy', a, b)

print(torch.all(two_step == c))
# expected output: tensor(True)    (NOTE: bitwise match is not guaranteed
# across platforms/backends — only closeness is)
print(torch.all(two_step == one_step))
# expected output: tensor(False)
print(torch.all(torch.isclose(two_step, one_step)))
# expected output: tensor(True)
I want to use the expression torch.einsum('bdhw,bhwnxy->bdxy', a, b), but I get a strange result.
Its result differs from torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
when the tensors are of floating-point type (with integer tensors the two agree exactly).
I believe this is due to floating-point rounding error, since the two expressions accumulate the sums in different orders.
However, even for float tensors, the result of torch.bmm(a.reshape(4, 5, 6*7), b.reshape(4, 6*7, -1)).reshape(4, 5, 9, 6, 7).sum(dim=2)
is bitwise identical to torch.einsum('bdnxy->bdxy', torch.einsum('bdhw,bhwnxy->bdnxy', a, b))
. I wonder why this pair is not affected by floating-point rounding in the same way.