Error when applying FloatFunctional sequentially

I am doing int8 quantization and I need to replace the mul operations in PyTorch.

answer = a_tensor * 0.2 * b_tensor

I tried to replace the multiplication operations with FloatFunctional calls, as shown below.

self.ff = nn.quantized.FloatFunctional()

d = self.ff.mul_scalar(a_tensor, 0.2)
answer = self.ff.mul(d, b_tensor)

But when I call torch.jit.trace(), I get the exception below.

    answer = self.ff.mul(d, b_tensor)
  File "/root/.pyenv/versions/3.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/functional_modules.py", line 160, in mul
    r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
RuntimeError: Mul operands should have same data type.

I printed out the dtypes:
print("### d.dtype", d.dtype)
print("### b_tensor.dtype", b_tensor.dtype)

And got the following output:
### d.dtype torch.quint8
### b_tensor.dtype torch.float32

Any good solution for this situation?

First, you will need to use a separate FloatFunctional instance for each invocation.

Also, the actual reason for the error is that b_tensor is not quantized: you'll need to add a QuantStub to quantize b_tensor before feeding it to FloatFunctional.mul.
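Putting both points together, a minimal sketch of the fixed module could look like the following (module and attribute names are illustrative; it assumes the eager-mode quantization API with torch.quantization.QuantStub):

```python
import torch
import torch.nn as nn

class MulModule(nn.Module):
    def __init__(self):
        super().__init__()
        # One FloatFunctional per call site, so each op gets its own observer.
        self.ff_scalar = nn.quantized.FloatFunctional()  # for mul_scalar
        self.ff_mul = nn.quantized.FloatFunctional()     # for mul
        # QuantStub quantizes b_tensor so both mul operands share a dtype
        # after convert(); before convert() it is a pass-through.
        self.quant_b = torch.quantization.QuantStub()

    def forward(self, a_tensor, b_tensor):
        d = self.ff_scalar.mul_scalar(a_tensor, 0.2)
        b_q = self.quant_b(b_tensor)  # quint8 after conversion, matching d
        return self.ff_mul.mul(d, b_q)
```

Before prepare/convert this runs as plain float math, so the module can still be traced and calibrated as usual; after torch.quantization.convert() both operands of the final mul are quantized and the dtype mismatch goes away.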