_joker
I’m looking for the PyTorch equivalent of this TensorFlow line:
binary_test = tf.cast(binary > 0.5, tf.float32)
I tried:
binary_test = binary.type(float32)
binary_test = torch.where(binary_test > 0.5, x, y)
where x = Tensor([1]) and y = Tensor([0]).
When I run this, the Python kernel crashes.
The posted code should generally work:
import torch
binary = torch.randint(0, 2, (10,)).byte()
binary_test = binary.type(torch.float32)
binary_test = torch.where(binary_test > 0.5, torch.tensor([1.]), torch.tensor([0.]))
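As a side note, torch.where is not strictly needed here: the comparison already yields a bool tensor, and casting it to float32 mirrors what tf.cast(binary > 0.5, tf.float32) does. A minimal sketch, assuming the same random 0/1 uint8 input as above (the names binary_test and binary_where are just for illustration):

```python
import torch

# Illustrative input: random 0/1 values stored as uint8, as in the snippet above.
binary = torch.randint(0, 2, (10,)).byte()

# The comparison gives a bool tensor; casting it to float32 produces 0.0/1.0,
# analogous to tf.cast(binary > 0.5, tf.float32).
binary_test = (binary > 0.5).to(torch.float32)

# The torch.where version, using 0-dim tensors that broadcast over `binary`.
binary_where = torch.where(binary > 0.5, torch.tensor(1.0), torch.tensor(0.0))

print(binary_test)
print(torch.equal(binary_test, binary_where))
```

Both forms should give the same float32 tensor; the direct cast avoids creating the extra x and y tensors.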
Note that I had to change float32 to torch.float32, and I’m not sure whether that’s a copy-paste issue or the error you might be running into.
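To illustrate the float32 point: a bare float32 is not a Python builtin, so unless it was imported or defined elsewhere in the session, binary.type(float32) raises a NameError. A small sketch, assuming a fresh namespace and the same illustrative random input, that checks this and shows a few equivalent ways to request a float32 copy:

```python
import torch

# Illustrative input, as above.
binary = torch.randint(0, 2, (10,)).byte()

# In a fresh namespace `float32` is undefined, so this raises NameError.
raised_name_error = False
try:
    binary.type(float32)  # noqa: F821 (intentionally undefined name)
except NameError:
    raised_name_error = True

# Equivalent ways to get a float32 copy of `binary`.
a = binary.type(torch.float32)
b = binary.to(torch.float32)
c = binary.float()

print(raised_name_error, a.dtype, torch.equal(a, b) and torch.equal(a, c))
```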
Could you please post the values of binary that cause this error?