Thanks @eqy, that link is very helpful.
As a quick summary (and to verify I got this right): it looks like the current byte-per-element layout exists to keep in-memory addressing simple, but some experts are already looking into fixing this, reusing features developed for the newer 2-bit/4-bit dtypes.
In the meantime, as a workaround, given that PyTorch has a set of bitwise logical operators that work on uint8, we could conceivably repack an 8*N-sized bool tensor into an N-sized uint8 tensor, with the following benefits:
- Using 8x less memory (PyTorch bools take one byte each)
- More efficient computations (maybe?)
And the following cons:
- Indexing/slicing the packed tensor at the sub-byte level is complicated without decompression (see the sketch after this list)
- Some more complex operations, like masking, would not be supported
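For example, reading a single bit out of a packed 1-D tensor still takes a shift and a mask; the get_bit helper below is just an illustration, not an existing API:

import torch

def get_bit(packed: torch.Tensor, i: int) -> torch.Tensor:
    # packed: flat uint8 tensor holding 8 bools per element, least
    # significant bit first; returns the i-th bool as a 0-dim tensor.
    return (packed[i // 8] >> (i % 8)) % 2 == 1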
As for the conversion method, my current best idea would be:
- converting from “uint8bool” to bool: right-shift the values by each bit index, and check whether the result is odd (mod 2)
- converting from bool to “uint8bool”: cast the bools to int, left-shift by each bit index, and sum
A quick demonstration in IPython:
In [1]: import torch
In [2]: uint8_tensor = torch.randint(0, 256, (1, 2), dtype=torch.uint8)  # upper bound is exclusive
In [3]: uint8_tensor
tensor([[67, 27]], dtype=torch.uint8)
In [4]: indexes = torch.arange(8).unsqueeze(0).unsqueeze(0)
In [5]: indexes
tensor([[[0, 1, 2, 3, 4, 5, 6, 7]]])
In [6]: bool_tensor = (uint8_tensor.unsqueeze(-1) >> indexes) % 2 == 1
In [7]: bool_tensor
tensor([[[ True,  True, False, False, False, False,  True, False],
         [ True,  True, False,  True,  True, False, False, False]]])
In [8]: reconverted = (bool_tensor.byte() << indexes).sum(dim=-1)
In [9]: reconverted
tensor([[67, 27]])
In [10]: uint8_tensor[0, 0] | uint8_tensor[0, 1]
tensor(91, dtype=torch.uint8)
In [11]: 1 + 2 + 8 + 16 + 64  # the bits set in 91
91
In [12]: uint8_tensor[0, 0] & uint8_tensor[0, 1]
tensor(3, dtype=torch.uint8)
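Wrapped up as reusable functions (a minimal sketch; the names pack_bool/unpack_bool are my own, and it assumes the last dimension of the bool tensor is a multiple of 8):

import torch

def pack_bool(b: torch.Tensor) -> torch.Tensor:
    # Pack a bool tensor (last dim a multiple of 8) into uint8, LSB first.
    shifts = torch.arange(8)
    b = b.reshape(*b.shape[:-1], -1, 8)       # group bits into bytes
    return (b.byte() << shifts).sum(dim=-1).to(torch.uint8)

def unpack_bool(p: torch.Tensor) -> torch.Tensor:
    # Unpack a uint8 tensor back into a bool tensor (8x the last dim).
    shifts = torch.arange(8)
    bits = (p.unsqueeze(-1) >> shifts) % 2 == 1
    return bits.reshape(*p.shape[:-1], -1)    # flatten bytes back into bits

# Round-trip check:
x = torch.rand(4, 16) > 0.5
assert torch.equal(unpack_bool(pack_bool(x)), x)
assert pack_bool(x).numel() == x.numel() // 8  # 8x fewer elements

Note that elementwise |, & and ^ can be applied directly to the packed tensors, which is where the memory and (possibly) compute savings come from.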
The “uint8bool” above has the most significant bit at the end (i.e. it is packed LSB-first), which might not be what most people prefer. In that case, inverting indexes should be enough.
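For example, a minimal sketch of the MSB-first variant, reusing the setup above:

import torch

# Reversed shift amounts: bit 7 comes first, bit 0 last.
indexes = torch.arange(7, -1, -1).unsqueeze(0).unsqueeze(0)
uint8_tensor = torch.tensor([[67, 27]], dtype=torch.uint8)
bool_tensor = (uint8_tensor.unsqueeze(-1) >> indexes) % 2 == 1
# 67 == 0b01000011 -> [False, True, False, False, False, False, True, True]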