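The issue is that you are using `torch.equal` on float tensors. `torch.equal` checks for exact element-wise equality, so values that differ only by floating-point rounding error compare as unequal. You can use `torch.allclose(y0[1], y1[0], atol=1e-6)` instead, which evaluates to `True` here.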
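Here is a minimal illustration of the difference. The tensors `a` and `b` are hypothetical stand-ins for `y0[1]` and `y1[0]`: they are mathematically equal but differ by rounding error, because `0.1 + 0.2` is not exactly `0.3` in float64.

```python
import torch

# Hypothetical stand-ins for y0[1] and y1[0]: equal on paper, not bit-for-bit
a = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

exact = torch.equal(a, b)                # exact comparison: sensitive to rounding error
close = torch.allclose(a, b, atol=1e-6)  # tolerant comparison: |a - b| <= atol + rtol * |b|

print(exact, close)
```

`torch.allclose` also applies a relative tolerance (`rtol`, default `1e-05`), so for large-magnitude values the allowed difference grows with `|b|`.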
I am not entirely sure what you are trying to do here, but note that `print(torch.equal(x0[0], x1[1]))` prints `False` as well.
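If it's unclear whether `x0[0]` and `x1[1]` are actually supposed to match, it can help to check how large the difference is before choosing a tolerance: a tiny difference points to rounding error, a large one to a real mismatch (for example, wrong indexing). A small sketch; `report_difference` is a hypothetical helper, not a PyTorch function, and the tensors in the usage lines are made up for illustration.

```python
import torch

def report_difference(t0: torch.Tensor, t1: torch.Tensor, atol: float = 1e-6) -> dict:
    """Hypothetical helper: summarize how two tensors differ."""
    if t0.shape != t1.shape:
        # Shapes differ, so an element-wise comparison is meaningless
        return {"same_shape": False, "exact": False, "close": False, "max_abs_diff": None}
    # Largest absolute element-wise difference, as a Python float
    max_abs_diff = (t0 - t1).abs().max().item()
    return {
        "same_shape": True,
        "exact": torch.equal(t0, t1),
        "close": torch.allclose(t0, t1, atol=atol),
        "max_abs_diff": max_abs_diff,
    }

# Made-up example tensors for illustration
x0 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x1 = torch.tensor([[3.0, 4.0], [1.0, 2.0]])
print(report_difference(x0[0], x1[1]))  # rows that really match
print(report_difference(x0[0], x1[0]))  # rows that differ by a lot
```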
Hope that helps