Why does each torch.bool value take up an entire byte?

In [2]: import torch
In [3]: t = ~torch.empty(100000000, dtype=torch.bool, device="cuda")
In [4]: print(torch.cuda.memory_allocated())
100663296
In [5]: print(100663296/100000000)
1.00663296

As shown above, a large torch.bool tensor appears to take up roughly 1 byte per element (at least on CUDA; I don't know how to measure this on CPU).
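On CPU there is no allocator counter like torch.cuda.memory_allocated() to read, but every tensor reports its own per-element size. Here is a minimal sketch using the standard torch.Tensor methods element_size() and nelement(). The tensor length is an arbitrary example value.

```python
import torch

# Allocate a boolean tensor on the CPU (length chosen arbitrarily for illustration)
t = torch.zeros(1_000_000, dtype=torch.bool)

bytes_per_item = t.element_size()               # bytes used by one element
total_bytes = t.element_size() * t.nelement()   # bytes used by the tensor's data

print(bytes_per_item)  # 1
print(total_bytes)     # 1000000
```

This counts only the tensor's data. It does not include allocator overhead or rounding, which is why the CUDA figure above is slightly more than 1 byte per item.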

Why is that? I would have assumed that a bool needs only 1 bit to represent, not 8 (a full byte).


The size of a bool is implementation dependent, and I think 1 byte is a relatively common choice. See also: torch.cuda.BoolTensor uses 8 bits per element, not 1 bit as reported by element_size() · Issue #41571 · pytorch/pytorch · GitHub
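The 1-byte choice is not specific to PyTorch. C's `_Bool`, exposed in Python's standard library as `ctypes.c_bool`, usually occupies one byte as well, the same as `unsigned char`. You can check this on the current platform as shown below. The size of `_Bool` is implementation-defined, so the first number could differ on unusual platforms.

```python
import ctypes

# Size in bytes of C's _Bool on this platform (implementation-defined, commonly 1)
bool_size = ctypes.sizeof(ctypes.c_bool)
# Size in bytes of unsigned char, which C defines to be exactly 1
char_size = ctypes.sizeof(ctypes.c_ubyte)

print(bool_size, char_size)
```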

Thanks @eqy, that link is very helpful.
As a quick summary (and to check that I got this right): it looks like this is currently implemented this way to keep memory addressing simple, but some experts are already looking into fixing it using features developed for the newer 2-bit/4-bit dtypes.

In the meantime, as a workaround: PyTorch has a set of bitwise logical operators that work with uint8, so we could conceivably reorganize an 8*N-sized bool tensor into an N-sized uint8 tensor, with the following benefits:

  1. Using less memory
  2. More efficient computations (maybe?)

And the following cons:

  1. Indexing/slicing the compressed tensor at the sub-byte level is complicated without decompression
  2. Some more complex operations, like masking, will not be supported directly
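To put a number on the first benefit, the sketch below compares the storage of the two layouts for the same number of flags. N is an arbitrary example size. The packed tensor is only allocated here, not filled.

```python
import torch

N = 1_000_000  # number of packed bytes; the bool layout needs 8 * N elements

unpacked = torch.zeros(8 * N, dtype=torch.bool)  # one byte per flag
packed = torch.zeros(N, dtype=torch.uint8)       # eight flags per byte

unpacked_bytes = unpacked.element_size() * unpacked.nelement()
packed_bytes = packed.element_size() * packed.nelement()

print(unpacked_bytes, packed_bytes, unpacked_bytes // packed_bytes)  # 8000000 1000000 8
```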

As for the conversion method, my current best idea would be:

  • converting from “uint8bool” to bool: right-shift the values by each bit index, and check whether the result is 1 mod 2 (i.e. odd)
  • converting from bool to “uint8bool”: convert the bools to integers, left-shift each one by its bit index, and sum

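Before moving to tensors, the two rules can be checked by hand on a single byte. This plain-Python sketch uses the value 67 (binary 01000011) and maps bit i of the byte to position i of the list:

```python
value = 67  # 0b01000011

# uint8bool -> bools: shift right by the index and test the lowest bit (mod 2)
bits = [(value >> i) % 2 == 1 for i in range(8)]
print(bits)  # [True, True, False, False, False, False, True, False]

# bools -> uint8bool: convert each bool to int, shift left by its index, and sum
restored = sum(int(b) << i for i, b in enumerate(bits))
print(restored)  # 67
```

Because each flag lands on a different bit, the sum never carries, so summing is equivalent to OR-ing the shifted bits together.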
In [1]: uint8_tensor = torch.randint(0, 256, (1, 2), dtype=torch.uint8)  # upper bound is exclusive
In [2]: uint8_tensor
tensor([[67, 27]], dtype=torch.uint8)

In [3]: indexes = torch.arange(8).unsqueeze(0).unsqueeze(0)
In [4]: indexes
tensor([[[0, 1, 2, 3, 4, 5, 6, 7]]])

In [5]: bool_tensor = (uint8_tensor.unsqueeze(-1) >> indexes) % 2 == 1
In [6]: bool_tensor
tensor([[[ True,  True, False, False, False, False,  True, False],
         [ True,  True, False,  True,  True, False, False, False]]])

In [7]: reconverted = (bool_tensor.byte() << indexes).sum(dim=-1)
In [8]: reconverted
tensor([[67, 27]])

In [9]: uint8_tensor[0,0] | uint8_tensor[0,1]
tensor(91, dtype=torch.uint8)
In [10]: 1+2+8+16+64
91
In [11]: uint8_tensor[0,0] & uint8_tensor[0,1]
tensor(3, dtype=torch.uint8)
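The 91 and 3 above are exactly what element-wise logical OR and AND of the unpacked flags produce. The sketch below checks that equivalence on random bytes, using the same shift-and-mod-2 unpacking as the session. It also covers XOR and NOT. The helper name `unpack_bits` is hypothetical.

```python
import torch

def unpack_bits(packed: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: expand each uint8 into 8 bools, least significant bit first."""
    indexes = torch.arange(8, device=packed.device)
    return (packed.unsqueeze(-1) >> indexes) % 2 == 1

torch.manual_seed(0)
a = torch.randint(0, 256, (1000,), dtype=torch.uint8)
b = torch.randint(0, 256, (1000,), dtype=torch.uint8)

# Bitwise ops on the packed bytes match logical ops on the unpacked bools
or_ok = torch.equal(unpack_bits(a | b), unpack_bits(a) | unpack_bits(b))
and_ok = torch.equal(unpack_bits(a & b), unpack_bits(a) & unpack_bits(b))
xor_ok = torch.equal(unpack_bits(a ^ b), unpack_bits(a) ^ unpack_bits(b))
not_ok = torch.equal(unpack_bits(~a), ~unpack_bits(a))

print(or_ok, and_ok, xor_ok, not_ok)  # True True True True
```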

The “uint8bool” layout puts the most significant bit at the end (index 7), which might not be what most people prefer. In that case, reversing the index order (e.g. indexes = torch.arange(7, -1, -1).unsqueeze(0).unsqueeze(0)) should be enough.
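The pieces can be combined into a pair of helpers for flat bool tensors of any length, sketched below. The input is zero-padded up to a multiple of 8, and an `msb_first` flag inverts the index order as suggested above. The names `pack_bools`/`unpack_bools` and the padding scheme are assumptions for illustration, not a PyTorch API.

```python
import torch

def pack_bools(flags: torch.Tensor, msb_first: bool = False) -> torch.Tensor:
    """Hypothetical helper: pack a 1-D bool tensor into uint8, 8 flags per byte."""
    pad = (-flags.numel()) % 8  # zero-pad so the length is a multiple of 8
    padded = torch.cat([flags, flags.new_zeros((pad,))]).view(-1, 8)
    indexes = torch.arange(8, device=flags.device)
    if msb_first:
        indexes = indexes.flip(0)  # first flag of each group goes to bit 7
    # Shift each bit into place and sum; bits never overlap, so the sum equals an OR
    return (padded.to(torch.uint8) << indexes).sum(dim=-1).to(torch.uint8)

def unpack_bools(packed: torch.Tensor, n: int, msb_first: bool = False) -> torch.Tensor:
    """Hypothetical helper: inverse of pack_bools; n is the original number of flags."""
    indexes = torch.arange(8, device=packed.device)
    if msb_first:
        indexes = indexes.flip(0)
    bits = (packed.unsqueeze(-1) >> indexes) % 2 == 1
    return bits.reshape(-1)[:n]  # drop the padding flags

# Usage: 10 flags need 2 bytes
flags = torch.tensor([True, False, True, True, False, False, False, False, True, True])
packed = pack_bools(flags)
restored = unpack_bools(packed, flags.numel())
print(packed.tolist())               # [13, 3]
print(torch.equal(restored, flags))  # True
```

The caller has to keep track of the original length `n`, because the padding bits are indistinguishable from genuine False flags.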