Hi,
I need to check whether my tensors are equal. I am using two equivalent ways to multiply batched matrices: an element-wise product followed by a sum, and torch.bmm. The results look the same, but when I compare them with the == operator, they don't seem to be.
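import torch
a = torch.rand(3, 2, 100)  # (batch, n, p)
b = torch.randn(3, 4, 2)   # (batch, m, n)
# element-wise product, then sum over n -> (3, 4, 100)
a_expanded = a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)
mul1 = (b.unsqueeze(3) * a_expanded).sum(dim=2)
# batched matrix multiply -> (3, 4, 100)
mul2 = torch.bmm(b, a)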
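As a sanity check that the two routes compute the same contraction, here is a minimal sketch that prints both shapes and measures how far apart the results actually are, rather than only whether they are bitwise equal. The fixed seed is my addition for reproducibility and is not part of the original code.

```python
import torch

torch.manual_seed(0)  # hypothetical fixed seed, only for reproducibility

a = torch.rand(3, 2, 100)  # (batch=3, n=2, p=100)
b = torch.randn(3, 4, 2)   # (batch=3, m=4, n=2)

# Element-wise route: (3, 4, 2, 1) * (3, 4, 2, 100), then sum over n -> (3, 4, 100)
a_expanded = a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)
mul1 = (b.unsqueeze(3) * a_expanded).sum(dim=2)

# Batched matmul route: (3, 4, 2) @ (3, 2, 100) -> (3, 4, 100)
mul2 = torch.bmm(b, a)

print(mul1.shape, mul2.shape)

# Largest absolute gap between the two results, and how many entries differ at all
max_diff = (mul1 - mul2).abs().max().item()
n_diff = (mul1 != mul2).sum().item()
print(f"max |mul1 - mul2| = {max_diff:.3e}, {n_diff} of {mul1.numel()} entries differ")
```

If the two routes only disagree through rounding, the maximum gap should be on the order of float32 rounding error for values of this size; the exact count of differing entries will depend on the machine and the PyTorch build.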
Now, if I compare them with mul1 == mul2, the result is:
tensor([[[1, 1, 0, ..., 1, 1, 1],
         [0, 0, 1, ..., 1, 1, 1],
         [0, 1, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 1, 1, 1]],

        [[0, 1, 0, ..., 1, 1, 1],
         [1, 0, 1, ..., 1, 1, 1],
         [0, 0, 1, ..., 1, 1, 1],
         [1, 0, 0, ..., 1, 1, 1]],

        [[1, 1, 0, ..., 1, 1, 1],
         [1, 0, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 1, 1, 1]]], dtype=torch.uint8)
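The 0 entries mark positions where the two results are not bitwise identical (recent PyTorch versions return a torch.bool tensor here instead of torch.uint8). One way to look at those positions is to index with the mask and print with more digits than PyTorch's default of four decimals. This is a sketch under my own assumptions: a fixed seed, and only the first five mismatches shown.

```python
import torch

torch.manual_seed(0)  # hypothetical seed for reproducibility
a = torch.rand(3, 2, 100)
b = torch.randn(3, 4, 2)
mul1 = (b.unsqueeze(3) * a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)).sum(dim=2)
mul2 = torch.bmm(b, a)

mismatch = mul1 != mul2   # boolean mask, True where the values differ
idx = mismatch.nonzero()  # (k, 3) tensor of (batch, row, col) positions
print(f"{idx.shape[0]} mismatching entries")

# The default display rounds to 4 decimals, which hides differences in the
# last bits of a float32. Show more digits for a handful of mismatches.
torch.set_printoptions(precision=10)
print(mul1[mismatch][:5])
print(mul2[mismatch][:5])
print((mul1 - mul2)[mismatch][:5])
torch.set_printoptions(profile="default")  # restore the default display
```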
When I examine the elements where it gave 0, the values in mul1 and mul2 look the same. What is wrong?
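A likely explanation (my reading; the post itself does not confirm it) is floating-point rounding. torch.bmm typically hands the work to a matrix-multiply kernel that may accumulate the products in a different order or with fused multiply-add, so individual entries can differ in the last bits even though both routes compute the same sum. The sketch below first shows float32 addition giving different answers for different groupings, then compares mul1 and mul2 with torch.allclose and torch.isclose instead of ==. The seed and the tolerance values are assumptions chosen for illustration; the default atol=1e-8 can still flag entries whose true value is close to zero, so it is worth setting atol to the scale of your data.

```python
import torch

# 1) Float addition is not associative: the grouping changes the rounding.
x = torch.tensor([1e8, 1.0, -1e8], dtype=torch.float32)
left = (x[0] + x[1]) + x[2]   # 1e8 + 1 rounds back to 1e8 in float32, so this is 0
right = (x[0] + x[2]) + x[1]  # 1e8 - 1e8 is exactly 0, so this is 1
print(left.item(), right.item())  # 0.0 1.0

# 2) Compare the two products with a tolerance instead of ==.
torch.manual_seed(0)  # hypothetical seed for reproducibility
a = torch.rand(3, 2, 100)
b = torch.randn(3, 4, 2)
mul1 = (b.unsqueeze(3) * a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)).sum(dim=2)
mul2 = torch.bmm(b, a)

bitwise_equal = torch.equal(mul1, mul2)  # exact check, may well be False
# Illustrative tolerances: passes where |mul1 - mul2| <= atol + rtol * |mul2|
close_all = torch.allclose(mul1, mul2, rtol=1e-5, atol=1e-5)
close_mask = torch.isclose(mul1, mul2, rtol=1e-5, atol=1e-5)  # per-element version
print(bitwise_equal, close_all, close_mask.all().item())
```

torch.allclose gives a single yes/no answer, while torch.isclose returns a mask like the one from ==, so it can replace == directly in the original comparison.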