IMHO, from the mathematical PoV, this is expected behavior. Integer arithmetic on fixed-width types is modular: every uint8 result is reduced modulo 2^8, so e.g. 212 * 220 = 46640, and 46640 mod 256 = 48. You can observe the same behavior across different frameworks.
For example, in numpy you would get identical results:
import numpy as np

a_list = [212, 181, 3, 89, 147]
b_list = [220, 207, 104, 228, 172]
a_np = np.array(a_list, dtype=np.uint8)
b_np = np.array(b_list, dtype=np.uint8)

# plain multiplication wraps around modulo 2**8
c_np = a_np * b_np
print(f'c_np={c_np}, c_np.dtype={c_np.dtype}')
# c_np=[ 48  91  56  68 196], c_np.dtype=uint8

# a wider `out` dtype does not help: the computation dtype is still
# chosen from the inputs, so the wrapped values are merely cast to uint16
c_np_out = np.zeros(shape=a_np.shape, dtype=np.uint16)
c_np_out = np.multiply(a_np, b_np, out=c_np_out)
print(f'c_np_out={c_np_out}, c_np_out.dtype={c_np_out.dtype}')
# c_np_out=[ 48  91  56  68 196], c_np_out.dtype=uint16
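To get the full-precision products in numpy, promote the inputs before multiplying. A minimal sketch (uint16 is enough here, since the largest possible uint8 product is 255 * 255 = 65025):

c_np_wide = a_np.astype(np.uint16) * b_np
print(f'c_np_wide={c_np_wide}, c_np_wide.dtype={c_np_wide.dtype}')
# c_np_wide=[46640 37467   312 20292 25284], c_np_wide.dtype=uint16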
As far as options go, you have several, ranging from simple boundary checks and type promotion up to torch customizations.
Boundary checks are usually application-specific, so I won't describe them in detail; a minimal sketch of the idea follows.
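Just to illustrate (a sketch only; check_mul_fits is a hypothetical helper name, not a torch API): do the multiplication in a wider dtype and compare the result against the bounds of the original dtype via torch.iinfo:

import torch

def check_mul_fits(a, b):
    # hypothetical helper: multiply in int32, then verify the result
    # still fits the bounds of the original (narrower) integer dtype
    wide = a.to(torch.int32) * b.to(torch.int32)
    info = torch.iinfo(a.dtype)
    if wide.max() > info.max or wide.min() < info.min:
        raise OverflowError(f'product does not fit into {a.dtype}')
    return wide.to(a.dtype)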
Option 1. Manual type promotion
import torch

a_list = [212, 181, 3, 89, 147]
b_list = [220, 207, 104, 228, 172]
a_torch = torch.tensor(a_list, dtype=torch.uint8)
b_torch = torch.tensor(b_list, dtype=torch.uint8)

# promoting one operand is enough; the other is promoted automatically
c_torch = a_torch * b_torch.to(torch.int32)
print(f'c_torch={c_torch}')
# c_torch=tensor([46640, 37467,   312, 20292, 25284], dtype=torch.int32)

# as in numpy, a wider `out` dtype does not avoid the wrap-around:
# the computation dtype comes from the inputs, not from `out`
c_torch_out = torch.zeros(size=a_torch.shape, dtype=torch.int16)
c_torch_out = torch.mul(a_torch, b_torch, out=c_torch_out)
print(f'c_torch_out={c_torch_out}')
# c_torch_out=tensor([ 48,  91,  56,  68, 196], dtype=torch.int16)
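If you want saturation instead of wrap-around while keeping uint8, a one-liner sketch on top of the same promotion (this is essentially what Option 2 below packages into a reusable type):

c_sat = (a_torch.to(torch.int32) * b_torch.to(torch.int32)).clamp(0, 255).to(torch.uint8)
print(f'c_sat={c_sat}')
# c_sat=tensor([255, 255, 255, 255, 255], dtype=torch.uint8)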
Option 2. Torch customizations
This one follows the official documentation here: Extending PyTorch — PyTorch 1.11.0 documentation
# Create a bounded tensor
class BoundedIntTensor:
    HANDLED_FUNCTIONS = {}

    def __init__(self, data, lo=0, hi=255, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return f"BoundedIntTensor ({self.lo}, {self.hi})\n\tdata: {self._t}"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in cls.HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, BoundedIntTensor))
            for t in types
        ):
            # fall back to the plain tensor implementation: unwrap any
            # BoundedIntTensor arguments to their underlying tensors
            args = [a._t if isinstance(a, BoundedIntTensor) else a for a in args]
            return func(*args, **kwargs)
        return cls.HANDLED_FUNCTIONS[func](*args, **kwargs)
# Implement custom multiplication logic
import functools

def implements(tensor_type, torch_function):
    """Register a torch function override for BoundedIntTensor"""
    @functools.wraps(torch_function)
    def decorator(func):
        tensor_type.HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator
@implements(BoundedIntTensor, torch.mul)
def mul(self, other):
    # multiply in a wider dtype, then clamp the result into [lo, hi]
    result = self._t.to(torch.int32) * other._t
    result[result > self.hi] = self.hi
    result[result < self.lo] = self.lo
    return BoundedIntTensor(result.to(self._t.dtype), lo=self.lo, hi=self.hi)
# Check the result
a_list = [212, 12, 3, 89, 147]
b_list = [220, 10, 3, 228, 172]
a = BoundedIntTensor(a_list, lo=0, hi=123)
b = BoundedIntTensor(b_list, lo=0, hi=123)
print(torch.mul(a, b))
# BoundedIntTensor (0, 123)
# data: tensor([123, 120, 9, 123, 123])
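A note on the design (assuming the isinstance-based unwrapping in __torch_function__ above): any torch function without a registered override falls back to the plain tensor implementation, so only the operations you care about need a bounded version:

print(torch.add(a, b))
# tensor([432,  22,   6, 317, 319])  <- plain torch.Tensor, no clamping applied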