I have developed a torch.nn system (full code can be found here) that performs Quantization Aware Training (QAT). My use case requires the scale factors of my nn.Linear activations and weights to be powers of 2 for deployment on neuromorphic hardware. Essentially, I need a bit-shifting system in which integer spike payloads are multiplied by integer weights and then scaled down appropriately to propagate through to the next layers.

```
# Integer bit-shifting example
scale_factor = activation_scale * weight_scale  # both powers of 2
shift = int(math.log2(scale_factor))  # shift by the exponent, not by the scale itself
quotient = calculated_spikes >> shift
```
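To make the intended arithmetic concrete, here is a minimal, standard-library-only sketch of the rescaling step. It assumes both scales are written as 2**-k (like the 0.0625 activation scale mentioned further down), so multiplying an integer accumulator by the combined scale is the same as shifting right by k bits. The helper names `shift_for_scales` and `rescale` and all the numbers are hypothetical and are not part of my model.

```python
import math


def shift_for_scales(activation_scale: float, weight_scale: float) -> int:
    """Return the right-shift amount for two power-of-two scales (hypothetical helper)."""
    combined = activation_scale * weight_scale
    exponent = math.log2(combined)
    if not exponent.is_integer():
        raise ValueError("combined scale is not a power of two")
    # A scale of 2**-k means dividing by 2**k, i.e. shifting right by k bits.
    return -int(exponent)


def rescale(accumulator: int, shift: int) -> int:
    """Scale an integer accumulator by 2**-shift using only bit shifts."""
    if shift >= 0:
        return accumulator >> shift
    return accumulator << -shift


# Illustrative values only: activation scale 2**-4, weight scale 2**-7.
shift = shift_for_scales(2 ** -4, 2 ** -7)
print(shift)                  # 11
print(rescale(40960, shift))  # 40960 / 2048 = 20
```

Note that `>>` on Python integers floors towards negative infinity, so negative accumulators round down rather than towards zero; whether that matches the target hardware is something to check separately.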

Whilst I have been successful in getting the activation scale factor to be a power of 2, I am struggling to get the weight scale to also be a power of 2. For the activations, I used a custom observer:

```
import math

import torch
from torch import quantization
from torch.quantization import MinMaxObserver


class PowerOfTwoMinMaxObserver(MinMaxObserver):
    """
    Observer module for computing the quantization parameters based on the
    running min and max values, with scales as powers of two.
    This observer extends the MinMaxObserver to use scales that are powers of two.
    It overrides the calculate_qparams method to compute the power of two scale.
    """

    def calculate_qparams(self):
        r"""Calculates the quantization parameters with scale as a power of two."""
        min_val, max_val = self.min_val.item(), self.max_val.item()
        # Round the scale up to the next power of two
        max_range = max(abs(min_val), abs(max_val))
        scale = 2 ** math.ceil(math.log2(max_range / (self.quant_max - self.quant_min)))
        # Calculate zero_point as in the base class
        if self.qscheme == torch.per_tensor_symmetric:
            if self.dtype == torch.qint8:
                zero_point = 0
            else:
                zero_point = 128
        else:
            zero_point = self.quant_min - round(min_val / scale)
        # Convert scale and zero_point to PyTorch tensors
        scale = torch.tensor(scale, dtype=torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)
        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=power of two"


qconfig = quantization.get_default_qat_qconfig('fbgemm')
# with_args() returns a factory for the custom observer, not an instance
custom_activation_observer = PowerOfTwoMinMaxObserver.with_args()
# Set the custom observer for activations in the qconfig
qconfig = torch.quantization.default_qconfig._replace(activation=custom_activation_observer)
```

Using this observer, the activation scale factor correctly ends up as 0.0625 (1/0.0625 = 16, a power of 2). However, I run into overflow issues when I use a similar observer for the weights.
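As a standalone illustration of the rounding step, the sketch below mirrors the scale formula from `calculate_qparams` using only the standard library. The function name `power_of_two_scale` and the min/max values are hypothetical; they were picked so the result happens to match the 0.0625 above and are not taken from my model.

```python
import math


def power_of_two_scale(min_val: float, max_val: float,
                       quant_min: int, quant_max: int) -> float:
    """Round the raw scale up to the next power of two (same formula as the observer)."""
    max_range = max(abs(min_val), abs(max_val))
    raw_scale = max_range / (quant_max - quant_min)
    return 2.0 ** math.ceil(math.log2(raw_scale))


# Illustrative observed range with an unsigned 8-bit quantization range.
scale = power_of_two_scale(-3.1, 12.0, 0, 255)
print(scale)             # 0.0625
print(math.log2(scale))  # -4.0
```

Because of the `ceil`, the scale is always rounded up, so it is never smaller than the raw min/max scale.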

```
class PowerOfTwoWeightObserver(MinMaxObserver):
    """
    Observer module for computing the quantization parameters based on the
    running min and max values, with scales as powers of two for weights.
    This observer extends the MinMaxObserver to use scales that are powers of two.
    It overrides the calculate_qparams method to compute the power of two scale.
    """

    def __init__(self, bit_width=8, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dtype = torch.qint8  # Default dtype
        self.bit_width = bit_width  # Specify the bit width

    def calculate_qparams(self):
        r"""Calculates the quantization parameters with scale as a power of two."""
        min_val, max_val = self.min_val.item(), self.max_val.item()
        # Round the scale up to the next power of two
        max_range = max(abs(min_val), abs(max_val))
        scale = 2 ** math.ceil(math.log2(max_range / (self.quant_max - self.quant_min)))
        # Calculate zero_point as in the base class
        if self.qscheme == torch.per_tensor_symmetric:
            zero_point = 0
        else:
            zero_point = self.quant_min - round(min_val / scale)
        # Convert scale and zero_point to PyTorch tensors
        scale = torch.tensor(scale, dtype=torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)
        # Adjust the scale based on the specified bit width
        scale = scale / (2 ** (self.bit_width - 1))
        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=power of two, bit_width={self.bit_width}"


custom_weight_observer = PowerOfTwoWeightObserver.with_args()
# Set the custom observers for activations and weights in the qconfig
qconfig = torch.quantization.default_qconfig._replace(activation=custom_activation_observer, weight=custom_weight_observer)
```

The training runs fine, but when it comes to converting the model to be quantized with `model = quantization.convert(model, inplace=False)`, I get the following error:

```
torch/ao/nn/quantized/modules/utils.py", line 40, in _clamp_weights
qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: value cannot be converted to type int8_t without overflow
```

I think this is most likely due to the calculated scale for the weights being a larger power of two than the activations, causing the weights to exceed the maximum for `int8` space.
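One way to test this hypothesis outside of PyTorch is to quantize a handful of weights by hand and check whether the resulting integers fit in the signed 8-bit range. The sketch below is standard-library only; the helpers `quantize` and `fits` and the weight values are hypothetical, and it assumes simple round-to-nearest with a zero point of 0.

```python
def quantize(values, scale, zero_point=0):
    """Map real values to integers with round-to-nearest (no clamping)."""
    return [round(v / scale) + zero_point for v in values]


def fits(ints, qmin, qmax):
    """Return True if every integer lies inside [qmin, qmax]."""
    return all(qmin <= q <= qmax for q in ints)


weights = [-0.42, 0.05, 0.31]  # illustrative values, not from my model

for scale in (2 ** -7, 2 ** -9):
    q = quantize(weights, scale)
    print(scale, q, fits(q, -128, 127))
# 0.0078125 [-54, 6, 40] True
# 0.001953125 [-215, 26, 159] False
```

In this toy example the integers grow as the scale shrinks, so it is the smaller scale that pushes values outside the `int8` range.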

I have tried lowering the `bit_width` to as low as 2 without any success. I have also already attempted to lower my initial weight values prior to training.

We can approximate the `scale_factor` by using the following, which works OK:

```
import numpy as np


# Function to approximate division using bit-shifting
def approximate_division(calculated_spikes, divisors):
    # divisors = activation_scale * weight_scale
    # where activation_scale is a power of 2, weight_scale is not
    # Round each divisor up to the next power of two (exponent only)
    nearest_powers = np.ceil(np.log2(divisors)).astype(int)
    # Right-shift dividends by the computed powers
    approximate_quotients = calculated_spikes >> nearest_powers
    return approximate_quotients
```
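To give a feel for how coarse this approximation can get, here is a scalar, standard-library-only restatement of the same idea compared against exact integer division. The name `approximate_division_scalar` and the dividend and divisor values are hypothetical and chosen purely for illustration.

```python
import math


def approximate_division_scalar(dividend: int, divisor: float) -> int:
    """Divide by rounding the divisor up to a power of two and shifting right."""
    shift = math.ceil(math.log2(divisor))
    return dividend >> shift


dividend = 4096
for divisor in (16, 20, 32):
    exact = dividend // divisor
    approx = approximate_division_scalar(dividend, divisor)
    print(divisor, exact, approx)
# 16 256 256
# 20 204 128
# 32 128 128
```

The result is exact when the divisor is already a power of two, but a divisor just above one (20 here) is treated as the next power up (32), which is why I would prefer a weight scale that is a true power of 2.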

If anybody has advice or suggestions on how we can extract a weight scale as a power of 2, I’d appreciate it. Thanks in advance!