How do I implement (overload?) >> for tensors?
@Geremia This already works in PyTorch:
# plain Python int
a = 10
print(a << 2)

# PyTorch integer tensor
import torch
a = torch.tensor(10)
print(a << 2)
Result:
40
tensor(40)
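Since the question was about >>, note that the right shift works the same way on integer tensors, so (as far as I can tell) no overloading is needed for them. A small sketch with arbitrary example values; the shift amount can be a Python int or another integer tensor:

```python
import torch

# Arbitrary example values (int64 by default).
x = torch.tensor([16, 40, 7])

right = x >> 2  # shift every element right by 2 bits
left = x << 1   # shift every element left by 1 bit

# The shift amount may itself be a tensor, applied elementwise.
per_elem = x >> torch.tensor([1, 2, 3])

print(right.tolist())     # [4, 10, 1]
print(left.tolist())      # [32, 80, 14]
print(per_elem.tolist())  # [8, 10, 0]
```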
Was this your question?
You are using integer values, for which the bit-shift operation is well defined.
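Concretely, for non-negative integers a shift by n bits is just multiplication or floor division by 2**n. A quick check of that, assuming small int32 values so that nothing overflows:

```python
import torch

# Assumption: small non-negative int32 values, so shifting by up to
# 3 bits cannot overflow and floor division matches the right shift.
x = torch.arange(0, 20, dtype=torch.int32)

results = []
for n in range(4):
    # Left shift by n should equal multiplication by 2**n.
    left_ok = torch.equal(x << n, x * 2**n)
    # Right shift by n should equal floor division by 2**n.
    right_ok = torch.equal(x >> n, x // 2**n)
    results.append((n, left_ok, right_ok))

for n, left_ok, right_ok in results:
    print(f"n={n}: << matches *: {left_ok}, >> matches //: {right_ok}")
```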
It’s unclear to me what the expected result of a bit shift on floating-point numbers would be. If you really want to shift the bits, view the tensor as an integer dtype, shift it, and view it back:
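import numpy as np
import torch

x = torch.tensor(1.25)  # float32 by default
y = x.view(torch.int32)  # reinterpret the same 32 bits as int32 (no copy)
print(y)
# tensor(1067450368, dtype=torch.int32)
print(np.binary_repr(y.numpy()))
# 111111101000000000000000000000
y = y << 1  # shift the raw bit pattern left by one
print(y)
# tensor(2134900736, dtype=torch.int32)
print(np.binary_repr(y.numpy()))
# 1111111010000000000000000000000
z = y.view(torch.float32)  # reinterpret the shifted bits as float32 again
print(z)
# tensor(2.5521e+38)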
But make sure this is really what you want, @Geremia.
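If this reinterpretation is needed in several places, it could be wrapped in a small helper. A minimal sketch, assuming float32 input; `shift_float_bits` is a hypothetical name, not a PyTorch API. Python's standard struct module serves only as an independent check of the bit pattern, and the last line compares against a plain multiplication by 2:

```python
import struct

import torch


def shift_float_bits(x: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: shift the raw bits of a float32 tensor left by n.

    The storage is reinterpreted as int32 (no copy), shifted, and
    reinterpreted back as float32. This is NOT multiplication by 2**n.
    """
    if x.dtype != torch.float32:
        raise TypeError("this sketch assumes a float32 tensor")
    return (x.view(torch.int32) << n).view(torch.float32)


x = torch.tensor(1.25)
z = shift_float_bits(x, 1)

# Independent check: raw IEEE 754 float32 bits of 1.25 via the struct module.
bits = struct.unpack("<I", struct.pack("<f", 1.25))[0]
print(hex(bits))                           # 0x3fa00000
print(x.view(torch.int32).item() == bits)  # True

print(z.item())        # roughly 2.55e+38
print((x * 2).item())  # 2.5
```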