Taylor-series Approximation for Sigmiod in Integer

Hi Mates,

I’m developing a Taylor series approximated sigmoid that can operate in the integer domain. I’m quantizing the series coefficient and input to INT-32 (input_scale * weight_scale) and holding the output scale to dequantize from INT-64 (64’bit due to bit-overflow). I like to validate my method and the math behind it and hear suggestions and ways to improve it.

Thanks
Vimal William

C++ Code for ApproxSigmoid:

#include <cmath>
#include <cstdint>
#include <pybind11/pybind11.h>
#include <stdexcept>
#include <torch/torch.h>
#include <vector>

class approxSigmoid {

public:
  approxSigmoid(float inScale, float outScale)
      : inScale_(inScale), outScale_(outScale) {
    initCoeff();
  }

  torch::Tensor forward(torch::Tensor &in) {
    auto tensorI32 = in.mul(1.0f / inScale_)
                         .round()
                         .clamp(INT32_MIN, INT32_MAX)
                         .to(torch::kInt32);
    auto inData = tensorI32.data_ptr<int32_t>();
    auto outTensor = torch::empty_like(tensorI32, torch::kFloat32);
    auto outData = outTensor.data_ptr<float>();

    auto size = tensorI32.numel();
    for (int i = 0; i < size; ++i) {
      int64_t tmp = computeSeries(inData[i]);
      outData[i] = static_cast<float>(tmp) * outScale_;
    }

    return outTensor;
  }

private:
  float inScale_, outScale_;
  std::vector<int32_t> qCoeff;
  std::vector<float> SIGMOID_COEFF{0.5f, 1.0f / 4.0f, 0.0f, -1.0f / 48.0f};

  void initCoeff() {
    for (size_t i = 0; i < SIGMOID_COEFF.size(); i++) {
      auto tmp = SIGMOID_COEFF[i] * inScale_;
      auto rNc = std::clamp(std::round(tmp), static_cast<float>(INT32_MIN), static_cast<float>(INT32_MAX));
      qCoeff.push_back(static_cast<int32_t>(rNc));
    }
  }

  int64_t computeSeries(int32_t x) {
    int64_t result = qCoeff[0];
    int64_t power = x;

    for (size_t i = 1; i < qCoeff.size(); i++) {
      result += qCoeff[i] * power;
      power *= x;  // Correct power accumulation
    }

    return result;
  }
};

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  pybind11::class_<approxSigmoid>(m, "approx_sigmoid")
      .def(pybind11::init<float, float>(), pybind11::arg("in_scale"),
           pybind11::arg("out_scale"))
      .def("forward", &approxSigmoid::forward);
}

Python Test Code:

import torch
import numpy as np
from torchvision import models
import matplotlib.pyplot as plt
import matplotlib as mpl

# Adjust the chunk size for Agg backend
mpl.rcParams['agg.path.chunksize'] = 10000  # Adjust as needed

try:
    from approx_sigmoid import approx_sigmoid
except ImportError:
    print(f"Import Failed")
    exit()

def quantize(x):
    dtype = torch.int8
    r_max, r_min = x.max().item(), x.min().item()
    q_max, q_min = torch.iinfo(dtype).max, torch.iinfo(dtype).min

    scale = (r_max - r_min) / (q_max - q_min)
    q_tensor = torch.clamp(torch.round(x / scale), q_min, q_max)
    return q_tensor, scale

def main():
    model = models.vgg16()
    tensor = torch.rand([3, 224, 224]) 

    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            q_weight, w_scale = quantize(module.weight)
            q_input, i_scale = quantize(tensor)

            out = module(tensor)
            flat_out = torch.flatten(out)

            # https://discuss.pytorch.org/t/is-bias-quantized-while-doing-pytorch-static-quantization/146416/5
            q_output, o_scale = quantize(out)
            scale_i32 = (w_scale * i_scale) 

            sigmoid_out = torch.sigmoid(flat_out[:1000])
            q_sig, s_scale = quantize(sigmoid_out)
            print(scale_i32)
            sigmoid = approx_sigmoid(scale_i32, s_scale)

            approx_out = sigmoid.forward(flat_out[:1000])

            # Plotting
            plt.figure(figsize=(10, 6))
            plt.plot(flat_out.detach().numpy()[:1000], sigmoid_out.detach().numpy(), label="True Sigmoid", color="blue")
            plt.plot(flat_out.detach().numpy()[:1000], approx_out.detach().numpy(), label="Approx. Sigmoid", color="orange", linestyle="--")
            plt.title("Sigmoid Approximation")
            plt.xlabel("Input")
            plt.ylabel("Output")
            plt.legend()
            plt.grid(True)
            plt.savefig("test.png")

            break  

if __name__ == "__main__":
    main()

Hi William!

As an aside, sigmoid (x) has a horizontal asymptote approaching zero
as x approaches -inf and a horizontal asymptote approaching one as
x approaches +inf. These asymptotes are not well approximated by
polynomials (such as Taylor series).

If you can afford an (integer) division, you might be better off with a
rational-function (the ratio of two polynomials) approximation, such as
a Padé approximant.

Note that sigmoid() is essentially the same as tanh() – just scaled and
shifted. It would probably be a little easier to work with and reason about
tanh().

Best.

K. Frank