I am trying to do bitwise operations on a bfloat16 tensor of arbitrary size, and I want to do this inside a JIT script. Normally, if the input tensor is a, we can call a.view(torch.int16) and then do bitwise masking. However, in a JIT script this doesn't work.
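For context, this is roughly the eager-mode version that works as expected outside of TorchScript (the mask value 0x7F80 is just an illustrative example that selects the bfloat16 exponent bits):

import torch

a = torch.tensor([1.0, -1.0, 0.5, -0.5, 2.0], dtype=torch.bfloat16)
bits = a.view(torch.int16)   # reinterpret the raw bfloat16 bits as int16
masked = bits & 0x7F80       # bitwise masking on the reinterpreted bits
print(bits)
print(masked)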
Below is the TorchScript code that fails:
@torch.jit.script
def get_bits(t: torch.Tensor) -> torch.Tensor:
    t = torch.tensor([1.0, -1.0, 0.5, -0.5, 2.0], dtype=torch.bfloat16)
    b = t.view(torch.int16)  # reinterpret the bits; this is the line that fails
    print(b)
    return b
And here is the error:
tensor([ 1.0000, -1.0000, 0.5000, -0.5000, 2.0000], dtype=torch.bfloat16)
Traceback (most recent call last):
File "/Users/lib/test.py", line 7, in <module>
t = round_to_fp8_represented_as_int8(x, 2)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "/Users/lib/bits.py", line 13, in round_to_fp8_represented_as_int8
) -> torch.Tensor:
b = t.view(torch.int16)
~~~~~~ <--- HERE
print(b)
RuntimeError: shape '[2]' is invalid for input of size 5
What is the issue here, and is there an alternative? This happens no matter what tensor size I use.