Custom weight observer for powers of 2

I have developed a torch.nn model (full code can be found here) that performs Quantization Aware Training (QAT). I have a very specific use case that requires the scale factors of my nn.Linear activations and weights to be powers of 2 for deployment on neuromorphic hardware. Essentially, I need a bit-shifting system in which integer spike payloads are multiplied by integer weights and then scaled down appropriately to propagate through to the next layers.

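# Integer bit-shifting example (scales expressed as divisors, e.g. 1/0.0625 = 16)
scale_factor = activation_scale * weight_scale  # both powers of 2
shift = int(math.log2(scale_factor))            # shift by the exponent, not by the scale itself
quotient = calculated_spikes >> shift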
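The shift amount is the exponent of the combined scale rather than the scale value itself. Below is a small pure-Python sketch of that step, assuming the scales are expressed as divisors ≥ 1 the same way the approximate_division function further down treats them (e.g. 1/0.0625 = 16). The helper names shift_for_scale and requantize are hypothetical.

```python
import math


def shift_for_scale(scale_factor: float) -> int:
    """Return k such that scale_factor == 2**k; raise if it is not a power of two."""
    mantissa, exponent = math.frexp(scale_factor)  # scale_factor == mantissa * 2**exponent
    if mantissa != 0.5:
        raise ValueError(f"{scale_factor} is not a power of two")
    return exponent - 1


def requantize(accumulator: int, activation_scale: float, weight_scale: float) -> int:
    """Divide an integer accumulator by a power-of-two combined scale using a right shift."""
    shift = shift_for_scale(activation_scale * weight_scale)
    if shift < 0:
        raise ValueError("expected a combined divisor >= 1")
    # Python's >> on ints is an arithmetic shift, i.e. floor(accumulator / 2**shift)
    return accumulator >> shift


print(requantize(1000, 16.0, 8.0))   # 1000 / 128 = 7.8125 -> 7
print(requantize(-1000, 16.0, 8.0))  # -7.8125 floors to -8
```

Note that the shift floors toward negative infinity, so negative accumulators round down rather than toward zero; whether that matches the hardware's shifter is worth checking.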

Whilst I have been successful in getting the activation scale factor to be a power of 2, I am struggling to get the weight scale to be a power of 2 as well. For the activations, I used a custom observer:

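import math

import torch
from torch import quantization
from torch.quantization import MinMaxObserver

class PowerOfTwoMinMaxObserver(MinMaxObserver):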
    """
    Observer module for computing the quantization parameters based on the
    running min and max values, with scales as powers of two.

    This observer extends the MinMaxObserver to use scales that are powers of two.
    It overrides the calculate_qparams method to compute the power of two scale.
    """

    def calculate_qparams(self):
        r"""Calculates the quantization parameters with scale as a power of two."""
        min_val, max_val = self.min_val.item(), self.max_val.item()

        # Round the ideal scale up to the next power of two
        max_range = max(abs(min_val), abs(max_val))
        scale = 2 ** math.ceil(math.log2(max_range / (self.quant_max - self.quant_min)))

        # Calculate zero_point as in the base class
        if self.qscheme == torch.per_tensor_symmetric:
            if self.dtype == torch.qint8:
                zero_point = 0
            else:
                zero_point = 128
        else:
            zero_point = self.quant_min - round(min_val / scale)

        # Convert scale and zero_point to PyTorch tensors
        scale = torch.tensor(scale, dtype=torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)
        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=power of two"

qconfig = quantization.get_default_qat_qconfig('fbgemm')
custom_activation_observer = PowerOfTwoMinMaxObserver.with_args()  # with_args() returns a factory (partial constructor), not an instance

# Set the custom observer for activations in the qconfig
qconfig = torch.quantization.default_qconfig._replace(activation=custom_activation_observer)

Using this observer, the activation scale factor settles at 0.0625 (1/0.0625 = 16 = 2^4, a power of 2). However, I run into overflow issues when I use the same approach for the weights.

class PowerOfTwoWeightObserver(MinMaxObserver):
    """
    Observer module for computing the quantization parameters based on the
    running min and max values, with scales as powers of two for weights.

    This observer extends the MinMaxObserver to use scales that are powers of two.
    It overrides the calculate_qparams method to compute the power of two scale.
    """

    def __init__(self, bit_width=8, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dtype = torch.qint8  # Default dtype
        self.bit_width = bit_width  # Specify the bit width

    def calculate_qparams(self):
        r"""Calculates the quantization parameters with scale as a power of two."""
        min_val, max_val = self.min_val.item(), self.max_val.item()

        # Round the ideal scale up to the next power of two
        max_range = max(abs(min_val), abs(max_val))
        scale = 2 ** math.ceil(math.log2(max_range / (self.quant_max - self.quant_min)))

        # Calculate zero_point as in the base class
        if self.qscheme == torch.per_tensor_symmetric:
            zero_point = 0
        else:
            zero_point = self.quant_min - round(min_val / scale)

        # Convert scale and zero_point to PyTorch tensors
        scale = torch.tensor(scale, dtype=torch.float32)
        zero_point = torch.tensor(zero_point, dtype=torch.int64)

        # Adjust the scale based on the specified bit width
        scale = scale / (2 ** (self.bit_width - 1))
        return scale, zero_point

    def extra_repr(self):
        return f"min_val={self.min_val}, max_val={self.max_val}, scale=power of two, bit_width={self.bit_width}"

custom_weight_observer = PowerOfTwoWeightObserver.with_args()
# Set the custom observers for both activations and weights in the qconfig
qconfig = torch.quantization.default_qconfig._replace(activation=custom_activation_observer, weight=custom_weight_observer)

The training runs fine, but when I convert the model with model = quantization.convert(model, inplace=False), I get the following error:

torch/ao/nn/quantized/modules/utils.py", line 40, in _clamp_weights
    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: value cannot be converted to type int8_t without overflow

I think this is most likely because the calculated weight scale is a larger power of two than the activation scale, causing the quantized weights to exceed the int8 range.

I have tried lowering the bit_width as low as 2 without any success. I have also already attempted to lower my initial weight values prior to training.
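To put numbers on this, the weight observer's scale arithmetic can be reproduced in plain Python. This sketch assumes the observer is still using MinMaxObserver's default quint8 range (quant_min=0, quant_max=255) and a hypothetical largest weight magnitude of 0.5; the helper name largest_int_needed is made up for illustration.

```python
import math


def largest_int_needed(max_abs_weight, quant_min, quant_max, bit_width):
    """Mirror the scale arithmetic of PowerOfTwoWeightObserver.calculate_qparams
    (affine branch) and return max_abs_weight / scale, i.e. the integer the
    largest weight would need to map to."""
    scale = 2 ** math.ceil(math.log2(max_abs_weight / (quant_max - quant_min)))
    scale = scale / (2 ** (bit_width - 1))  # the extra bit_width adjustment
    return max_abs_weight / scale


INT8_MAX = 127
MAX_ABS_WEIGHT = 0.5  # hypothetical largest |weight|

# Assumes the observer still uses MinMaxObserver's default quint8 range (0, 255)
needed_8bit = largest_int_needed(MAX_ABS_WEIGHT, 0, 255, bit_width=8)
needed_2bit = largest_int_needed(MAX_ABS_WEIGHT, 0, 255, bit_width=2)
print(needed_8bit, needed_2bit, needed_8bit > INT8_MAX, needed_2bit > INT8_MAX)
# -> 16384.0 256.0 True True
```

Under these assumptions, dividing the scale by 2 ** (bit_width - 1) shrinks it, which pushes the weights further out of range rather than pulling them in, so even bit_width=2 stays above 127. Separately, a quant_max of 255 is itself not representable in int8.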

As a workaround, we can approximate the division by the (non-power-of-two) scale_factor with a bit shift, which works OK:

import numpy as np

# Function to approximate division using bit-shifting
def approximate_division(calculated_spikes, divisors):
    # divisors = activation_scale * weight_scale,
    # where activation_scale is a power of 2 but weight_scale is not.
    # Round each divisor up to the next power of two and take its exponent
    nearest_powers = np.ceil(np.log2(divisors)).astype(int)
    # Right-shift dividends by the computed powers
    approximate_quotients = calculated_spikes >> nearest_powers
    
    return approximate_quotients
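To get a feel for how far this ceil-based approximation is from exact integer division, here is a usage example with made-up integer spike counts and a divisor of 100 (not a power of two). The function is restated compactly so the snippet runs on its own, and numpy is assumed to be available.

```python
import numpy as np


def approximate_division(calculated_spikes, divisors):
    # Round each divisor up to the next power of two and take its exponent
    nearest_powers = np.ceil(np.log2(divisors)).astype(int)
    # Right-shift the (integer) dividends by the computed exponents
    return calculated_spikes >> nearest_powers


calculated_spikes = np.array([1000, 2000, 4096])  # made-up integer accumulators
divisors = np.array([100.0, 100.0, 100.0])         # not a power of two

approx = approximate_division(calculated_spikes, divisors)
exact = calculated_spikes // divisors.astype(int)
print(approx, exact)  # [ 7 15 32] [10 20 40]
```

Because the ceil rounds 100 up to 128, every quotient is scaled by 100/128 ≈ 0.78 before flooring, so the approximation is systematically low. A weight scale that is genuinely a power of two would remove this bias.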

If anybody has advice or suggestions on how we can extract a weight scale as a power of 2, I’d appreciate it. Thanks in advance!

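Maybe you can try setting quant_min/quant_max for the custom power-of-two weight observer to (-128, 127)? I think the default might be the quint8 range (0, 255). That would fit the traceback: _clamp_weights fills a copy of the int8 weight representation with the observer's quant_max, and 255 cannot be stored in int8. Assigning self.dtype = torch.qint8 after super().__init__() probably doesn't help, since in recent PyTorch versions MinMaxObserver derives quant_min/quant_max from dtype inside its constructor, so pass dtype (and qscheme) to super().__init__() instead.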
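Along the lines of that suggestion, here is a minimal sketch of a weight observer that hands dtype, qscheme and the (-128, 127) range to MinMaxObserver.__init__ instead of assigning self.dtype afterwards, and rounds the symmetric per-tensor scale up to a power of two. The class name PowerOfTwoSymmetricWeightObserver is hypothetical, it deliberately drops the extra division by 2 ** (bit_width - 1), and it has not been tested end to end with quantization.convert().

```python
import math

import torch
from torch.ao.quantization import MinMaxObserver


class PowerOfTwoSymmetricWeightObserver(MinMaxObserver):
    """Hypothetical symmetric int8 weight observer with a power-of-two scale.

    dtype, qscheme and the quantized range are handed to MinMaxObserver.__init__,
    so quant_min/quant_max are set up for qint8 (-128..127) from the start.
    """

    def __init__(self, **kwargs):
        kwargs.setdefault("dtype", torch.qint8)
        kwargs.setdefault("qscheme", torch.per_tensor_symmetric)
        kwargs.setdefault("quant_min", -128)
        kwargs.setdefault("quant_max", 127)
        super().__init__(**kwargs)

    def calculate_qparams(self):
        """Return (scale, zero_point) with scale = 2**k and zero_point = 0."""
        max_abs = max(abs(self.min_val.item()), abs(self.max_val.item()))
        if not math.isfinite(max_abs) or max_abs == 0.0:
            # Nothing observed yet (min/max still +/-inf) or all-zero weights
            scale = 1.0
        else:
            # Smallest power of two with max_abs / scale <= quant_max, so nothing is clipped
            scale = 2.0 ** math.ceil(math.log2(max_abs / self.quant_max))
        return (
            torch.tensor(scale, dtype=torch.float32),
            torch.tensor(0, dtype=torch.int64),
        )
```

It could then replace the original weight observer via weight=PowerOfTwoSymmetricWeightObserver.with_args() in the qconfig. Note that a bare observer only records min/max; for QAT proper you may want to wrap it, e.g. FakeQuantize.with_args(observer=PowerOfTwoSymmetricWeightObserver), so the weights are actually fake-quantized during training.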