I’m looking for the PyTorch equivalent of this TensorFlow line:
binary_test = tf.cast(binary > 0.5, tf.float32)
My attempt so far:
binary_test = binary.type(float32)
binary_test = torch.where(binary_test > 0.5, x, y)
where x and y are tensors (x = Tensor(), y = Tensor()).
When I run this, the Python kernel crashes.
The posted code should generally work. For example, this runs as expected:
import torch

binary = torch.randint(0, 2, (10,)).byte()  # random 0/1 values, dtype uint8
binary_test = binary.type(torch.float32)
binary_test = torch.where(binary_test > 0.5, torch.tensor([1.]), torch.tensor([0.]))
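If the goal is only to reproduce the TensorFlow line, torch.where may not be needed at all: in PyTorch a comparison returns a bool tensor, which can be cast to float32 directly. This is a minimal sketch, assuming a reasonably recent PyTorch version where comparing a uint8 tensor against the float 0.5 is supported; the name via_where is just for illustration.

```python
import torch

# Same kind of random 0/1 input as in the snippet above (dtype uint8)
binary = torch.randint(0, 2, (10,)).byte()

# Direct analogue of tf.cast(binary > 0.5, tf.float32):
# the comparison produces a bool tensor, and .to() casts it to float32
binary_test = (binary > 0.5).to(torch.float32)

# The torch.where version from above, for comparison; the (1,)-shaped
# tensors broadcast against the (10,)-shaped condition
via_where = torch.where(binary.float() > 0.5, torch.tensor([1.]), torch.tensor([0.]))

print(binary_test.dtype)                    # torch.float32
print(torch.equal(binary_test, via_where))  # True
```

(binary > 0.5).float() is a shorter spelling of the same cast.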
Note that I had to change float32 to torch.float32, and I’m not sure whether that’s a copy-paste issue or whether you might be running into this error.
Could you please post the values of binary that trigger this error?