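When comparing float or double tensors, use torch.allclose() instead of ==. Because of rounding errors, two results that are mathematically equal can differ in their last few bits, so == (or torch.equal()) may report them as different.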
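A minimal sketch of the difference. It uses explicit float64 inputs because 0.1 + 0.2 is a well-known case that does not round to exactly 0.3 in double precision. torch.allclose() checks |a - b| <= atol + rtol * |b| elementwise, with default rtol=1e-05 and atol=1e-08. The tighter atol=1e-12 in the last check is an illustrative choice, not a recommended value.

```python
import torch

# Explicit float64 so the rounding behaviour is the familiar double-precision one.
a = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

exact = torch.equal(a, b)          # False: a is 0.30000000000000004
elementwise = a == b               # tensor([False])
close = torch.allclose(a, b)       # True: difference is far below atol + rtol * |b|

# Tolerances can be tightened when the dtype's precision allows it.
close_tight = torch.allclose(a, b, rtol=0.0, atol=1e-12)

print(exact, elementwise, close, close_tight)
```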
Another related note about float vs. double:
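By default, PyTorch creates floating-point tensors as torch.float (float32). float32 carries about 7 significant decimal digits (machine epsilon is about 1.19 × 10^-7), so for values of order 1 you should expect absolute errors of around 10^-6 after a few operations. If you set the default dtype to double with torch.set_default_dtype(torch.double), new floating-point tensors become float64 (machine epsilon about 2.22 × 10^-16), and the errors drop to roughly 10^-15 or below.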
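The sketch below reads both machine epsilons from torch.finfo and shows how torch.set_default_dtype changes the dtype of a tensor built from a Python float. It assumes the process starts with the stock float32 default, and it restores the previous default afterwards so the rest of the program is unaffected. Computing 1/3 is just an illustrative operation.

```python
import torch

print(torch.finfo(torch.float32).eps)   # ≈ 1.19e-07
print(torch.finfo(torch.float64).eps)   # ≈ 2.22e-16

# Python floats are converted using the current default dtype.
x32 = torch.tensor(1.0) / 3

old_default = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
try:
    x64 = torch.tensor(1.0) / 3         # now created as float64
finally:
    torch.set_default_dtype(old_default)  # restore the previous default

# Error relative to the Python (double) value of 1/3.
err32 = abs(x32.item() - 1 / 3)
err64 = abs(x64.item() - 1 / 3)
print(x32.dtype, err32)
print(x64.dtype, err64)
```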