Multiplying two uint8 tensors without overflow

I want to multiply two uint8 tensors, for example:

>>> a = torch.randint(low=0, high=256, size=(5,), dtype=torch.uint8)
>>> b = torch.randint(low=0, high=256, size=(5,), dtype=torch.uint8)
>>> a
tensor([212, 181,   3,  89, 147], dtype=torch.uint8)
>>> b
tensor([220, 207, 104, 228, 172], dtype=torch.uint8)
>>> a * b
tensor([ 48,  91,  56,  68, 196], dtype=torch.uint8)

However, the result dtype is also uint8 by default, so the product overflows. For the first element the result is 212 * 220 % 256 = 48, but I want 212 * 220 = 46640. I know it is impossible to hold that result in uint8, but I cannot find any way to store it in a tensor with a wider dtype. I have also tried the following:

>>> c = torch.empty(5, dtype=torch.int16)
>>> torch.mul(a, b, out=c)
tensor([ 48,  91,  56,  68, 196], dtype=torch.int16)

Though the result is stored in an int16 tensor now, it is still truncated.
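A short sketch of why that happens: out= only decides where the result is stored, while the computation dtype is determined by the inputs through type promotion. torch.result_type can be used to check the computation dtype before running the op:

```python
import torch

a = torch.tensor([212, 181, 3, 89, 147], dtype=torch.uint8)
b = torch.tensor([220, 207, 104, 228, 172], dtype=torch.uint8)

# Type promotion only looks at the inputs: uint8 * uint8 -> uint8.
print(torch.result_type(a, b))  # torch.uint8

# The product is computed in uint8 (wrapping modulo 256) and only then
# copied into the int16 output, so the wrapped values are what gets stored.
c = torch.empty(5, dtype=torch.int16)
torch.mul(a, b, out=c)
print(c.tolist())  # [48, 91, 56, 68, 196]
```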

I know I can first convert the uint8 tensors to int16 and then do the multiplication, but that would add extra overhead and I want to avoid it. Therefore, I am wondering whether it is possible to store the result in a wider tensor without truncation. Thanks
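One thing worth checking before promoting: int16 is not actually wide enough here. A uint8 × uint8 product can reach 255 * 255 = 65025, while torch.iinfo(torch.int16).max is 32767, so the 46640 from the first element would still overflow. int32 is the smallest signed type that holds every such product, and promoting one operand is enough for type promotion to do the rest:

```python
import torch

a = torch.tensor([212, 181, 3, 89, 147], dtype=torch.uint8)
b = torch.tensor([220, 207, 104, 228, 172], dtype=torch.uint8)

max_product = 255 * 255                              # 65025, the largest uint8 x uint8 product
print(max_product > torch.iinfo(torch.int16).max)    # True: int16 (max 32767) is too narrow
print(max_product <= torch.iinfo(torch.int32).max)   # True: int32 is wide enough

# Promoting one operand is enough; uint8 * int32 is computed and returned as int32.
c = a.to(torch.int32) * b
print(c.dtype, c.tolist())  # torch.int32 [46640, 37467, 312, 20292, 25284]
```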

We could probably change quantization to always saturate to max/min on overflow.
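"Set to max/min when overflowing" is saturating arithmetic. In eager PyTorch it can be emulated by computing in a wider type, clamping to the uint8 range, and casting back. A minimal sketch; saturating_mul_uint8 is a hypothetical helper, not a PyTorch API:

```python
import torch

def saturating_mul_uint8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 multiply that saturates at 0/255 instead of wrapping."""
    wide = a.to(torch.int32) * b.to(torch.int32)   # exact product, at most 65025
    return wide.clamp(0, 255).to(torch.uint8)      # saturate, then narrow back to uint8

a = torch.tensor([212, 12, 3, 89, 147], dtype=torch.uint8)
b = torch.tensor([220, 10, 3, 228, 172], dtype=torch.uint8)
print(saturating_mul_uint8(a, b).tolist())  # [255, 120, 9, 255, 255]
```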

To clarify, the current question is not about quantization as defined by torch.ao, but rather about the low-precision dtypes. However, I think the quantized dtypes should be aware of the current behavior.

Hi Zafar,

I agree this question is not about quantization, but I could not find a more appropriate category. I thought this question would come up frequently when doing int8 arithmetic for quantization. If you have any suggestions for the category, I am glad to change it.

I am asking this question from a pure math angle. Is there any way to get the results without overflow? For example, is there a Python or C++ API I can use?

Thanks

IMHO, from the mathematical PoV, this is expected behavior: fixed-width integer arithmetic is modular arithmetic (modulo 2^8 for uint8), so the wrapped results are what it should produce. You will see the same behavior across different frameworks.
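Concretely, an n-bit unsigned multiply returns the exact product reduced modulo 2**n. For uint8 (n = 8) this can be checked with plain Python integers, which never overflow:

```python
a_list = [212, 181, 3, 89, 147]
b_list = [220, 207, 104, 228, 172]

exact = [x * y for x, y in zip(a_list, b_list)]  # arbitrary-precision ints, no overflow
wrapped = [p % 2**8 for p in exact]               # what a uint8 multiply keeps

print(exact)    # [46640, 37467, 312, 20292, 25284]
print(wrapped)  # [48, 91, 56, 68, 196]
```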

For example, in numpy you would get identical results:

import numpy as np

a_list = [212, 181,   3,  89, 147]
b_list = [220, 207, 104, 228, 172]

a_np = np.array(a_list, dtype=np.uint8)
b_np = np.array(b_list, dtype=np.uint8)

c_np = a_np * b_np
print(f'c_np={c_np}, c_np.dtype={c_np.dtype}')
# c_np=[ 48  91  56  68 196], c_np.dtype=uint8

c_np_out = np.zeros(shape=a_np.shape, dtype=np.uint16)
c_np_out = np.multiply(a_np, b_np, out=c_np_out)
print(f'c_np_out={c_np_out}, c_np_out.dtype={c_np_out.dtype}')
# c_np_out=[ 48  91  56  68 196], c_np_out.dtype=uint16
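As that snippet shows, out= does not widen the computation in NumPy either. NumPy ufuncs do, however, accept a dtype= argument that selects the type the loop computes in (the inputs are cast first). A sketch; uint16 is enough here because 255 * 255 = 65025 fits below 65535:

```python
import numpy as np

a_np = np.array([212, 181, 3, 89, 147], dtype=np.uint8)
b_np = np.array([220, 207, 104, 228, 172], dtype=np.uint8)

# dtype= chooses the computation loop, so the multiply itself runs in uint16.
c_np = np.multiply(a_np, b_np, dtype=np.uint16)
print(c_np.tolist(), c_np.dtype)  # [46640, 37467, 312, 20292, 25284] uint16
```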

As far as the options go, you have several, ranging from simple boundary checks, through type promotion, to torch customizations.

I'm not describing the boundary checks, as those are usually application specific.
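As one illustration of such a check (mul_with_overflow_check is a hypothetical helper, not something PyTorch provides): for non-negative inputs, no element-wise product can exceed a.max() * b.max(), so that single scalar decides whether the uint8 multiply is safe or a wider type is needed. Note that the max() reductions are themselves a full pass over the data.

```python
import torch

def mul_with_overflow_check(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply uint8 tensors, widening only if a product could overflow."""
    # Worst-case bound computed with Python ints, so the check itself cannot overflow.
    worst_case = int(a.max()) * int(b.max())
    if worst_case <= torch.iinfo(torch.uint8).max:
        return a * b                    # every product fits: stay in uint8
    return a.to(torch.int32) * b        # otherwise compute exactly in int32

small = torch.tensor([3, 5, 7], dtype=torch.uint8)
big = torch.tensor([212, 12], dtype=torch.uint8)
print(mul_with_overflow_check(small, small).dtype)  # torch.uint8
print(mul_with_overflow_check(big, big).dtype)      # torch.int32
```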

Option 1. Manual type promotion

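import torch

a_list = [212, 181,   3,  89, 147]
b_list = [220, 207, 104, 228, 172]

a_torch = torch.tensor(a_list, dtype=torch.uint8)
b_torch = torch.tensor(b_list, dtype=torch.uint8)

# Promote one operand: uint8 * int32 is computed (and returned) as int32, so nothing wraps.
c_torch = a_torch * b_torch.to(torch.int32)
print(f'c_torch={c_torch}')

# For comparison: out= only sets the storage dtype. The product is still computed
# in uint8 (the inputs' dtype) and wraps before being copied into the int16 tensor.
c_torch_out = torch.zeros(size=a_torch.shape, dtype=torch.int16)
c_torch_out = torch.mul(a_torch, b_torch, out=c_torch_out)
print(f'c_torch_out={c_torch_out}')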

Option 2. Torch customizations

This one follows the documentation here: "Extending PyTorch" (PyTorch 1.11.0 documentation).

import functools
import torch

# Create a bounded tensor
class BoundedIntTensor:
    HANDLED_FUNCTIONS = {}

    def __init__(self, data, lo=0, hi=255, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return f"BoundedIntTensor ({self.lo}, {self.hi})\n\tdata: {self._t}"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in cls.HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, BoundedIntTensor))
            for t in types
        ):
            # Fallback: unwrap BoundedIntTensor arguments and call the plain torch function
            args = [a._t if isinstance(a, BoundedIntTensor) else a for a in args]
            return func(*args, **kwargs)
        return cls.HANDLED_FUNCTIONS[func](*args, **kwargs)

# Implement custom multiplication logic
def implements(tensor_type, torch_function):
    """Register a torch function override for tensor_type (here BoundedIntTensor)"""
    @functools.wraps(torch_function)
    def decorator(func):
        tensor_type.HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

@implements(BoundedIntTensor, torch.mul)
def mul(self, other):
    # Compute the exact product in a wider type, then clip it to [lo, hi] (saturation)
    result = self._t.to(torch.int32) * other._t
    result = result.clamp(min=self.lo, max=self.hi)
    return BoundedIntTensor(result.to(self._t.dtype), lo=self.lo, hi=self.hi)

# Check the result
a_list = [212, 12, 3,  89, 147]
b_list = [220, 10, 3, 228, 172]

a = BoundedIntTensor(a_list, lo=0, hi=123)
b = BoundedIntTensor(b_list, lo=0, hi=123)

print(torch.mul(a, b))
# BoundedIntTensor (0, 123)
# 	data: tensor([123, 120,   9, 123, 123])

Sorry for the confusion; I understand this is expected behavior. I was just hoping to find an alternative way to avoid the overflow and get the mathematically correct answer.

So grateful for your detailed reply! If I understand correctly, the underlying arithmetic in both options is int32 multiplication. However, as I mentioned earlier, the reason I wanted to do this is that I wanted to take advantage of the speedup of uint8 multiplication. If I do the multiplication after promoting the uint8 tensors to int32, I am not sure how efficient it will be.
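Whether the promotion really costs much is easiest to settle by measuring on the target hardware. A rough sketch using the standard-library timeit; the tensor size and repeat count are arbitrary choices, and the numbers will vary by machine, thread count, and PyTorch build:

```python
import timeit
import torch

n = 1_000_000
reps = 20
a = torch.randint(0, 256, (n,), dtype=torch.uint8)
b = torch.randint(0, 256, (n,), dtype=torch.uint8)

# Baseline: the wrapping uint8 multiply.
t_uint8 = timeit.timeit(lambda: a * b, number=reps)
# Promotion included: the two casts are part of what is timed.
t_int32 = timeit.timeit(lambda: a.to(torch.int32) * b.to(torch.int32), number=reps)

print(f"uint8: {t_uint8 / reps * 1e3:.3f} ms/iter, "
      f"int32 incl. casts: {t_int32 / reps * 1e3:.3f} ms/iter")
```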

One reason I posted this in the quantization category is that this issue should be universal in quantization. For example, how is the quantized linear layer computed?

If we ignore the zero point, scaling factor, and bias for now, the linear layer boils down to multiplying two int8 tensors (weight and input). To get the correct answer, this multiplication must not overflow. I thought the multiplication is done in int8, the result is stored in a tensor with more bits, and then it is converted back to int8. I am curious: how is this step implemented efficiently to get the ideal speedup?

the reason I wanted to do this is that I wanted to take advantage of the speedup of uint8 multiplication

IMHO, to take advantage of uint8 arithmetic, you would have to make some assumptions, such as the input/output ranges. Otherwise, there is no way of knowing whether the multiplication overflows.
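A sketch of what such a range assumption buys: if the operand ranges are known up front, the range of their product follows from the four corner products, and from that the narrowest integer type that can hold it. The helper names and the small type table below are hypothetical illustrations:

```python
# Hypothetical table of integer types and their (min, max) ranges, narrowest first.
INT_RANGES = {
    "uint8": (0, 2**8 - 1),
    "int8": (-2**7, 2**7 - 1),
    "uint16": (0, 2**16 - 1),
    "int16": (-2**15, 2**15 - 1),
    "int32": (-2**31, 2**31 - 1),
}

def product_range(a_lo, a_hi, b_lo, b_hi):
    """Range of x * y for x in [a_lo, a_hi], y in [b_lo, b_hi]; the extremes lie at the corners."""
    corners = [a_lo * b_lo, a_lo * b_hi, a_hi * b_lo, a_hi * b_hi]
    return min(corners), max(corners)

def narrowest_type(lo, hi):
    """First entry of INT_RANGES whose range contains [lo, hi], or None."""
    for name, (t_lo, t_hi) in INT_RANGES.items():
        if t_lo <= lo and hi <= t_hi:
            return name
    return None

print(product_range(0, 255, 0, 255), narrowest_type(*product_range(0, 255, 0, 255)))
# (0, 65025) uint16
print(product_range(0, 15, 0, 15), narrowest_type(*product_range(0, 15, 0, 15)))
# (0, 225) uint8   (e.g. if the inputs are known to be 4-bit)
print(product_range(-128, 127, -128, 127), narrowest_type(*product_range(-128, 127, -128, 127)))
# (-16256, 16384) int16
```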

I am curious: how is this step implemented efficiently to get the ideal speedup?

Quantization doesn’t make assumptions; instead, it forces the user to provide the output scale and zero_point (see here). Here is an example:

import torch

a_list = [212, 181,   3,  89, 147]
b_list = [220, 207, 3, 228, 172]

a_torch = torch.tensor(a_list, dtype=torch.uint8)
b_torch = torch.tensor(b_list, dtype=torch.uint8)

qa_torch = torch.quantize_per_tensor(a_torch.to(torch.float), scale=1.0, zero_point=0, dtype=torch.quint8)
qb_torch = torch.quantize_per_tensor(b_torch.to(torch.float), scale=1.0, zero_point=0, dtype=torch.quint8)

print(f'- qa_torch={qa_torch}') 
print(f'- qa_torch.int_repr={qa_torch.int_repr()}')
print(f'- qb_torch={qb_torch}')
print(f'- qb_torch.int_repr={qb_torch.int_repr()}')
# - qa_torch=tensor([212., 181.,   3.,  89., 147.], size=(5,), dtype=torch.quint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)
# - qa_torch.int_repr=tensor([212, 181,   3,  89, 147], dtype=torch.uint8)
# - qb_torch=tensor([220., 207.,   3., 228., 172.], size=(5,), dtype=torch.quint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)
# - qb_torch.int_repr=tensor([220, 207,   3, 228, 172], dtype=torch.uint8)

qc_torch = torch.ops.quantized.mul(qa_torch, qb_torch, scale=1.0, zero_point=0)
print(f'- qc_torch={qc_torch}')
print(f'- qc_torch.int_repr={qc_torch.int_repr()}')
# - qc_torch=tensor([255., 255.,   9., 255., 255.], size=(5,), dtype=torch.quint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)
# - qc_torch.int_repr=tensor([255, 255,   9, 255, 255], dtype=torch.uint8)
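The 255s come from the requested output scale=1.0: the quint8 output grid still has only 256 levels, so products above 255 are clipped. The plain-PyTorch emulation below of the dequantize, multiply, requantize steps is an approximation for illustration, not the actual quantized kernel, whose rounding details may differ. It shows how a larger output scale trades resolution for range; emulated_qmul is a hypothetical helper:

```python
import torch

def emulated_qmul(a_q, b_q, a_scale, a_zp, b_scale, b_zp, out_scale, out_zp):
    """Hypothetical emulation of a quint8 multiply: dequantize, multiply, requantize."""
    a = (a_q.to(torch.float32) - a_zp) * a_scale
    b = (b_q.to(torch.float32) - b_zp) * b_scale
    q = torch.round(a * b / out_scale) + out_zp
    return q.clamp(0, 255).to(torch.uint8)   # clip to the quint8 range

a_q = torch.tensor([212, 181, 3, 89, 147], dtype=torch.uint8)
b_q = torch.tensor([220, 207, 3, 228, 172], dtype=torch.uint8)

print(emulated_qmul(a_q, b_q, 1.0, 0, 1.0, 0, out_scale=1.0, out_zp=0).tolist())
# [255, 255, 9, 255, 255]  (saturated, like the output above)
print(emulated_qmul(a_q, b_q, 1.0, 0, 1.0, 0, out_scale=256.0, out_zp=0).tolist())
# [182, 146, 0, 79, 99]    (no saturation; each value now stands for q * 256)
```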

One reason I posted this in the quantization category is that this issue should be universal in quantization. For example, how is the quantized linear layer computed?

If we ignore the zero point, scaling factor, and bias for now, the linear layer boils down to multiplying two int8 tensors (weight and input). To get the correct answer, this multiplication must not overflow. I thought the multiplication is done in int8, the result is stored in a tensor with more bits, and then it is converted back to int8. I am curious: how is this step implemented efficiently to get the ideal speedup?

Here is how the operation is performed: FBGEMM/OutputProcessing-inl.h at 9d7c48a65419d0350f9e9e72f31e05bfe37e85a4 · pytorch/FBGEMM · GitHub

Basically, everything is performed in int32; then, to apply the scale, the value is converted to float, and finally it is rounded to long right before clipping.
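The pipeline just described (int32 accumulation, float scaling, rounding, clipping) can be sketched in plain PyTorch for a linear layer, assuming zero points of 0 and no bias. This is an illustration only: int8_linear_requant and the scale values are hypothetical, and FBGEMM's vectorized kernels (with zero-point and bias handling) differ in the details:

```python
import torch

def int8_linear_requant(x_q, w_q, x_scale, w_scale, out_scale):
    """Hypothetical sketch of a quantized linear (zero points = 0, no bias).

    x_q: (batch, in_features) int8, w_q: (out_features, in_features) int8.
    """
    # 1) Widen to int32 and accumulate the dot products in int32
    #    (dtype= keeps the sum in int32; by default sum() of ints returns int64).
    acc = (x_q.to(torch.int32)[:, None, :] * w_q.to(torch.int32)[None, :, :]).sum(
        dim=-1, dtype=torch.int32
    )
    # 2) Apply the combined scale in float.
    y = acc.to(torch.float32) * (x_scale * w_scale / out_scale)
    # 3) Round to the nearest integer and clip to the int8 output range.
    return torch.round(y).clamp(-128, 127).to(torch.int8), acc

torch.manual_seed(0)
x_q = torch.randint(-128, 128, (4, 16)).to(torch.int8)   # activations
w_q = torch.randint(-128, 128, (8, 16)).to(torch.int8)   # weights
out_q, acc = int8_linear_requant(x_q, w_q, x_scale=0.02, w_scale=0.01, out_scale=0.05)
print(acc.dtype, out_q.dtype, out_q.shape)  # torch.int32 torch.int8 torch.Size([4, 8])
```

A single int8 × int8 product fits in 16 bits, but a sum of many of them does not, which is why the accumulator is 32 bits wide.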

Thanks, Zafar and ParGG!

I think my previous understanding of quantization was incorrect. The zero point and scaling factor are also an integral part of quantization.

As ParGG said, if the int8 arithmetic in FBGEMM is also performed in int32, I should not worry about the efficiency of type promotion.

This is the first time I have posted on the PyTorch forum, and you have all been very helpful! I really appreciate your help.