.view function in JIT scripts

I am trying to do bitwise operations on a bfloat16 tensor of arbitrary size, and I want to do this inside a TorchScript (JIT) function. Normally, if the input tensor is t, I can reinterpret its bits with t.view(torch.int16) and then apply bitwise masks. However, inside a @torch.jit.script function this doesn't work.
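For reference, here is the eager-mode version of the bit reinterpretation and masking. The field widths are the standard bfloat16 layout (1 sign bit, 8 exponent bits, 7 mantissa bits); the specific masks and variable names are only an illustration and are not taken from the original code.

```python
import torch

t = torch.tensor([1.0, -1.0, 0.5, -0.5, 2.0], dtype=torch.bfloat16)

# Reinterpret the same 16-bit storage as int16; no values are converted
bits = t.view(torch.int16)

# bfloat16 layout: 1 sign bit | 8 exponent bits | 7 mantissa bits
sign = (bits >> 15) & 0x1
exponent = (bits >> 7) & 0xFF
mantissa = bits & 0x7F

print(bits.tolist())      # [16256, -16512, 16128, -16640, 16384]
print(sign.tolist())      # [0, 1, 0, 1, 0]
print(exponent.tolist())  # [127, 127, 126, 126, 128]
```

Masking each field after the shift keeps the result the same whether the shift of a negative int16 is arithmetic or logical.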

Below is the code:

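import torch

@torch.jit.script
def get_bits(t: torch.Tensor) -> torch.Tensor:
  # Reinterpret the bfloat16 bits as int16 (works in eager mode, fails when scripted)
  b = t.view(torch.int16)
  print(b)
  return b

t = torch.tensor([1.0, -1.0, 0.5, -0.5, 2.0], dtype=torch.bfloat16)
get_bits(t)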

And here is the error:

tensor([ 1.0000, -1.0000,  0.5000, -0.5000,  2.0000], dtype=torch.bfloat16)
Traceback (most recent call last):
  File "/Users/lib/test.py", line 7, in <module>
    t = round_to_fp8_represented_as_int8(x, 2)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/lib/bits.py", line 13, in round_to_fp8_represented_as_int8
) -> torch.Tensor:

    b = t.view(torch.int16)
        ~~~~~~ <--- HERE
    print(b)
RuntimeError: shape '[2]' is invalid for input of size 5

What’s the issue here? What is an alternative? This is an issue no matter what tensor size I use.

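The issue is how TorchScript resolves the overloaded .view() call. Inside a scripted function a dtype is represented by its integer ScalarType code, and torch.int16 (ScalarType Short) has code 2. So t.view(torch.int16) matches the view(size) overload as if you had written t.view(2), and PyTorch tries to reshape the tensor to shape [2]. That is exactly what the error message says, and it is why the call fails for every size except 2 (where it silently returns a reshaped bfloat16 tensor instead of int16 bits). Casting with .to(torch.int16) is not a replacement: it converts values numerically (1.0 becomes 1, 0.5 becomes 0) instead of reinterpreting the bit pattern, so masking its result does not operate on the bfloat16 bits. Two alternatives: pass the dtype by keyword, t.view(dtype=torch.int16), which should only be able to match the dtype overload; or do the reinterpretation in eager Python before calling the scripted function, pass the int16 tensor in, and view the result back as torch.bfloat16 afterwards if needed.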
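A minimal sketch of the second alternative, assuming the bitwise work can live in a scripted helper that only ever receives int16 data. The names mask_bits and get_bits and the 0x7FFF mask (clear the sign bit, i.e. an absolute value) are hypothetical and chosen only for illustration. The last lines contrast .view with .to on the same input.

```python
import torch

# Hypothetical scripted helper: it only sees int16 data, so the ambiguous
# view-with-dtype call never has to be resolved by TorchScript.
@torch.jit.script
def mask_bits(bits: torch.Tensor, mask: int) -> torch.Tensor:
    # Bitwise AND applied to the raw 16-bit patterns
    return bits & mask

def get_bits(t: torch.Tensor) -> torch.Tensor:
    # Eager mode: reinterpret the bfloat16 storage as int16 (no value conversion)
    return t.view(torch.int16)

t = torch.tensor([1.0, -1.0, 0.5, -0.5, 2.0], dtype=torch.bfloat16)
bits = get_bits(t)

# 0x7FFF keeps the lower 15 bits, clearing the sign bit of every element
abs_bits = mask_bits(bits, 0x7FFF)

# View the masked bits back as bfloat16: values 1.0, 1.0, 0.5, 0.5, 2.0
abs_vals = abs_bits.view(torch.bfloat16)
print(abs_vals)

# For contrast: .to() converts values numerically (1, -1, 0, 0, 2),
# so the original bit patterns are lost
converted = t.to(torch.int16)
print(converted)
```

Keeping the dtype reinterpretation in eager code is the conservative choice here, since it sidesteps TorchScript's overload resolution entirely; the scripted part then works only with plain integer tensors.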