Runtime error "Add operands must be the same size" when using quantized model for inference

Hi everyone,

I’m trying to quantize a GAN model using static quantization. The model is quantized successfully, but when I use the quantized model for inference, it throws a RuntimeError.

Here is the stack trace

And here is the module that throws the error:

import torch
import torch.nn as nn


class AdaIN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_1 = torch.quantization.QuantStub()
        self.quant_2 = torch.quantization.QuantStub()
        self.quant_3 = torch.quantization.QuantStub()

        self.add_gamma_functional = nn.quantized.FloatFunctional()
        self.mul_gamma_with_norm = nn.quantized.FloatFunctional()
        self.add_with_beta = nn.quantized.FloatFunctional()
        
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, gamma, x_normalized, beta):
        gamma = self.quant_1(gamma)
        x_normalized = self.quant_2(x_normalized)
        beta = self.quant_3(beta)

        print(gamma.shape, beta.shape, x_normalized.shape)

        # return (1 + gamma) * self.norm(x) + beta
        # return (1 + gamma) * x_normalized + beta
        # Change to below to make module quantizable

        result1 = self.add_gamma_functional.add_scalar(gamma, 1)
        result2 = self.mul_gamma_with_norm.mul(result1, x_normalized)
        result3 = self.add_with_beta.add(result2, beta)
        result3 = self.dequant(result3)
        
        return result3

Is this a bug, or do I have to change the code manually so that broadcasting is applied when running inference with the quantized model?

Here is a link to a Colab notebook that reproduces the issue.

P.S. The shapes of gamma, x_normalized, and beta are:
gamma: (1, 512, 1, 1)
x_normalized: (1, 512, 16, 16)
beta: (1, 512, 1, 1)

The module computes only a simple expression:
result = (gamma + 1) * x_normalized + beta
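For comparison, here is a minimal float sketch of the same expression with the shapes listed above (the random values are placeholders, not data from the thread). In float mode PyTorch broadcasts the size-1 spatial dims of gamma and beta, which gives the same result as expanding them to the shape of x_normalized first:

```python
import torch

torch.manual_seed(0)

# Shapes taken from the post; the values are random placeholders.
gamma = torch.randn(1, 512, 1, 1)
x_normalized = torch.randn(1, 512, 16, 16)
beta = torch.randn(1, 512, 1, 1)

# Float tensors broadcast the size-1 H and W dims of gamma and beta implicitly.
result = (gamma + 1) * x_normalized + beta

# The same computation with the broadcast made explicit via expand_as.
explicit = (gamma + 1).expand_as(x_normalized) * x_normalized + beta.expand_as(x_normalized)

print(result.shape)  # torch.Size([1, 512, 16, 16])
print(torch.allclose(result, explicit))  # True
```

In the quantized model this implicit expansion apparently does not happen for the add, so the explicit form is what has to be reproduced by hand.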

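I believe broadcasting is not supported for quantized addition. A simple workaround is to repeat the smaller tensor along the broadcast axes so that both operands have the same shape, e.g. y.repeat(1, x.shape[1], x.shape[2]) for x of shape (512, 16, 16) and y of shape (512, 1, 1).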
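The repeat counts can also be derived from the two shapes instead of being hard-coded. A minimal sketch of a hypothetical helper, repeat_factors (not a PyTorch API), assuming both shapes have the same number of dims and follow the usual broadcasting rule (each source dim is either 1 or already equal to the target dim):

```python
def repeat_factors(src_shape, dst_shape):
    """Return per-dimension repeat counts that tile src_shape up to dst_shape.

    Hypothetical helper: assumes equal rank and that every src dim is either 1
    or already equal to the matching dst dim.
    """
    if len(src_shape) != len(dst_shape):
        raise ValueError(f"rank mismatch: {tuple(src_shape)} vs {tuple(dst_shape)}")
    factors = []
    for s, d in zip(src_shape, dst_shape):
        if s == d:
            factors.append(1)  # dimension already matches
        elif s == 1:
            factors.append(d)  # size-1 dimension is tiled d times
        else:
            raise ValueError(f"cannot broadcast {tuple(src_shape)} to {tuple(dst_shape)}")
    return tuple(factors)


# y (512, 1, 1) against x (512, 16, 16): same as y.repeat(1, 16, 16)
print(repeat_factors((512, 1, 1), (512, 16, 16)))  # (1, 16, 16)
# beta (1, 512, 1, 1) against a (1, 512, 16, 16) tensor
print(repeat_factors((1, 512, 1, 1), (1, 512, 16, 16)))  # (1, 1, 16, 16)
```

Since torch.Size is a tuple, the helper could be used as y.repeat(*repeat_factors(y.shape, x.shape)).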

import torch
from time import time

x = torch.ones(512, 16, 16)
y = torch.ones(512, 1, 1)

z = x + y  # float reference: regular tensors broadcast y automatically

repeat_total_time = 0.0
add_total_time = 0.0
for _ in range(1000):
  qx = torch.quantize_per_tensor(x, 1.0, 0, torch.qint8)
  qy = torch.quantize_per_tensor(y, 1.0, 0, torch.qint8)
  
  start_time = time()
  # Quantized add needs equal shapes, so tile qy from (512, 1, 1) to (512, 16, 16)
  qy = qy.repeat(1, qx.shape[1], qx.shape[2]).contiguous()
  repeat_total_time += time() - start_time

  start_time = time()
  qz = torch.ops.quantized.add(qx, qy, 1.0, 0)
  add_total_time += time() - start_time


print(f'Time to repeat: {repeat_total_time * 1000:.2f} ms')
print(f'Time to add: {add_total_time * 1000:.2f} ms')
print(f'Overhead: {(repeat_total_time + add_total_time) / add_total_time:.2f}x')

# Time to repeat: 27.84 ms
# Time to add: 73.29 ms
# Overhead: 1.38x
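The timing loop does not check the result itself. A quick sanity check, sketched here with the same inputs and quantization parameters as the benchmark, is to compare the dequantized output of the repeated quantized add with the float broadcast result:

```python
import torch

x = torch.ones(512, 16, 16)
y = torch.ones(512, 1, 1)
z = x + y  # float reference computed with broadcasting

# Same per-tensor parameters as the benchmark: scale 1.0, zero point 0, qint8.
qx = torch.quantize_per_tensor(x, 1.0, 0, torch.qint8)
qy = torch.quantize_per_tensor(y, 1.0, 0, torch.qint8)

# Tile qy to qx's shape so the quantized add sees equal-sized operands.
qy = qy.repeat(1, qx.shape[1], qx.shape[2])
qz = torch.ops.quantized.add(qx, qy, 1.0, 0)

print(qz.shape)  # torch.Size([512, 16, 16])
# With all-ones inputs and scale 1.0, every element should dequantize to exactly 2.0.
print(torch.equal(qz.dequantize(), z))  # True
```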

In your code, you would need to make this change:

    def forward(self, gamma, x_normalized, beta):
        # ...
        result1 = self.add_gamma_functional.add_scalar(gamma, 1)
        result2 = self.mul_gamma_with_norm.mul(result1, x_normalized)
        if beta.shape != result2.shape:
            # Quantized add does not broadcast: tile beta (N, C, 1, 1) to result2's (N, C, H, W)
            beta = beta.repeat(1, 1, result2.shape[2], result2.shape[3]).contiguous()
        result3 = self.add_with_beta.add(result2, beta)
        result3 = self.dequant(result3)
        
        return result3
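To check the change end to end, here is a minimal eager-mode static quantization sketch. Assumptions not taken from the thread: random inputs in [0, 1) used for both calibration and the comparison, torch.quantization.default_qconfig instead of a backend-specific qconfig, and gamma/beta whose spatial dims are 1. It also tiles result1 before the mul, so that the sketch does not depend on whether the quantized mul broadcasts in a given PyTorch version:

```python
import torch
import torch.nn as nn


class AdaIN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_1 = torch.quantization.QuantStub()
        self.quant_2 = torch.quantization.QuantStub()
        self.quant_3 = torch.quantization.QuantStub()
        self.add_gamma_functional = nn.quantized.FloatFunctional()
        self.mul_gamma_with_norm = nn.quantized.FloatFunctional()
        self.add_with_beta = nn.quantized.FloatFunctional()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, gamma, x_normalized, beta):
        gamma = self.quant_1(gamma)
        x_normalized = self.quant_2(x_normalized)
        beta = self.quant_3(beta)
        h, w = x_normalized.shape[2], x_normalized.shape[3]
        # (1 + gamma), still of shape (N, C, 1, 1)
        result1 = self.add_gamma_functional.add_scalar(gamma, 1)
        # Assumption: tile both (N, C, 1, 1) operands so no quantized op has to broadcast.
        result1 = result1.repeat(1, 1, h, w)
        beta = beta.repeat(1, 1, h, w)
        result2 = self.mul_gamma_with_norm.mul(result1, x_normalized)
        result3 = self.add_with_beta.add(result2, beta)
        return self.dequant(result3)


torch.manual_seed(0)
# Values in [0, 1) keep the quantization error small for the comparison below.
gamma = torch.rand(1, 512, 1, 1)
x_normalized = torch.rand(1, 512, 16, 16)
beta = torch.rand(1, 512, 1, 1)

model = AdaIN().eval()
model.qconfig = torch.quantization.default_qconfig
prepared = torch.quantization.prepare(model)
prepared(gamma, x_normalized, beta)  # calibration pass: observers record value ranges
quantized = torch.quantization.convert(prepared)

out = quantized(gamma, x_normalized, beta)
ref = (gamma + 1) * x_normalized + beta  # float reference
print(out.shape)  # torch.Size([1, 512, 16, 16])
print((out - ref).abs().max())
```

The remaining difference from the float reference is 8-bit quantization error, which grows with the value ranges the observers see.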

P.S. You might be able to skip the .contiguous() call: I think the result of repeat is already contiguous, but I am not 100% sure.
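Whether .contiguous() is needed can be checked directly with Tensor.is_contiguous(). A minimal sketch on a quantized tensor shaped like the benchmark's y:

```python
import torch

y = torch.ones(512, 1, 1)
qy = torch.quantize_per_tensor(y, 1.0, 0, torch.qint8)

# repeat allocates a new tensor and copies the tiled values into it.
tiled = qy.repeat(1, 16, 16)
print(tiled.shape)            # torch.Size([512, 16, 16])
print(tiled.is_contiguous())
```

If this prints True on your PyTorch version, the extra .contiguous() call is a no-op.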


Thank you so much @Zafar, I’ll give it a try.