Multi-CPU doesn't give a speed-up with non-conv parts

Hi,
I’m trying to quantize the torchcrepe module (GitHub - maxrmorrison/torchcrepe: Pytorch implementation of the CREPE pitch tracker).

I used PyTorch’s static quantization.

For all settings, the module size shrank to about a quarter of the original, which is expected since int8 weights take a quarter of the space of fp32. The execution time, however, did not always decrease.

With 8 threads,

  • the execution time actually increased (23.349s → 26.471s)

With 1 thread,

  • the execution time did decrease (88.641s → 34.877s)
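
(For reference, the thread count for the two runs was controlled with torch.set_num_threads, roughly like this:)

import torch

torch.set_num_threads(8)    # multi-CPU run
# torch.set_num_threads(1)  # single-CPU run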

To investigate the root cause in the multi-CPU setting, I tracked down the bottleneck (the second conv layer) and quantized just that layer in a standalone module.
The results were:

  • the execution time dropped by roughly 8x in both settings (multi-CPU: 7.850s → 0.983s; single-CPU: 31.405s → 3.334s)
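
(For anyone who wants to reproduce the per-op breakdown, one way to locate a bottleneck like this is torch.autograd.profiler; a minimal sketch, where model and example_input stand in for the module under test and its input:)

import torch

# Profile a single forward pass and print the most expensive ops by CPU time
with torch.autograd.profiler.profile() as prof:
    model(example_input)
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))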

It looks like the non-conv parts are not contributing to the speed-up in multi-CPU settings. Does anyone know about this phenomenon? Could you suggest any directions?

Below are the output and the code. Thanks in advance.

Output - Multi-CPU (8 threads)

CUDA_VISIBLE_DEVICES="" python test.py                                         
========================================================================
=== Before Quantization=================================================
Crepe(
  (conv1): Conv2d(1, 1024, kernel_size=(512, 1), stride=(4, 1))
  (conv1_BN): BatchNorm2d(1024, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv2): Conv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv2_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv3): Conv2d(128, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv3_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv4): Conv2d(128, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv4_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv5): Conv2d(128, 256, kernel_size=(64, 1), stride=(1, 1))
  (conv5_BN): BatchNorm2d(256, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv6): Conv2d(256, 512, kernel_size=(64, 1), stride=(1, 1))
  (conv6_BN): BatchNorm2d(512, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (classifier): Linear(in_features=2048, out_features=360, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
Size (MB):  88.990
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1153.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Execution Time (s): 23.349
/opt/conda/lib/python3.8/site-packages/torch/quantization/observer.py:122: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:448.)
  return torch.floor_divide(self, other)
=== After Quantization==================================================
Crepe(
  (conv1): QuantizedConv2d(1, 1024, kernel_size=(512, 1), stride=(4, 1), scale=0.027386803179979324, zero_point=65)
  (conv1_BN): QuantizedBatchNorm2d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): QuantizedConv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.009443704970180988, zero_point=66)
  (conv2_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): QuantizedConv2d(128, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.003322516568005085, zero_point=64)
  (conv3_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): QuantizedConv2d(128, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.001186757697723806, zero_point=66)
  (conv4_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): QuantizedConv2d(128, 256, kernel_size=(64, 1), stride=(1, 1), scale=0.0003566498344298452, zero_point=66)
  (conv5_BN): QuantizedBatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): QuantizedConv2d(256, 512, kernel_size=(64, 1), stride=(1, 1), scale=0.0001743721222737804, zero_point=64)
  (conv6_BN): QuantizedBatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): QuantizedLinear(in_features=2048, out_features=360, scale=0.00038746290374547243, zero_point=62, qscheme=torch.per_channel_affine)
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
Size (MB):  22.340
Execution time (s): 26.471
==================================================
========================================================================
=== Before Quantization=================================================
ConvModel(
  (conv): Conv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1))
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
Size (MB):  33.556
Execution Time (s): 7.850
=== After Quantization==================================================
ConvModel(
  (conv): QuantizedConv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.022107142955064774, zero_point=56)
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
Size (MB):  8.394
Execution time (s): 0.983
==================================================

Output - Single-CPU (1 thread)

➜ CUDA_VISIBLE_DEVICES="" python test.py
========================================================================
=== Before Quantization=================================================
Crepe(
  (conv1): Conv2d(1, 1024, kernel_size=(512, 1), stride=(4, 1))
  (conv1_BN): BatchNorm2d(1024, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv2): Conv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv2_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv3): Conv2d(128, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv3_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv4): Conv2d(128, 128, kernel_size=(64, 1), stride=(1, 1))
  (conv4_BN): BatchNorm2d(128, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv5): Conv2d(128, 256, kernel_size=(64, 1), stride=(1, 1))
  (conv5_BN): BatchNorm2d(256, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (conv6): Conv2d(256, 512, kernel_size=(64, 1), stride=(1, 1))
  (conv6_BN): BatchNorm2d(512, eps=0.001, momentum=0.0, affine=True, track_running_stats=False)
  (classifier): Linear(in_features=2048, out_features=360, bias=True)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
Size (MB):  88.990
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1153.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Execution Time (s): 88.641
/opt/conda/lib/python3.8/site-packages/torch/quantization/observer.py:122: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:448.)
  return torch.floor_divide(self, other)
=== After Quantization==================================================
Crepe(
  (conv1): QuantizedConv2d(1, 1024, kernel_size=(512, 1), stride=(4, 1), scale=0.025484636425971985, zero_point=66)
  (conv1_BN): QuantizedBatchNorm2d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): QuantizedConv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.009398533962666988, zero_point=71)
  (conv2_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): QuantizedConv2d(128, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.003245722968131304, zero_point=58)
  (conv3_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): QuantizedConv2d(128, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.001026029814966023, zero_point=57)
  (conv4_BN): QuantizedBatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): QuantizedConv2d(128, 256, kernel_size=(64, 1), stride=(1, 1), scale=0.0004001582565251738, zero_point=72)
  (conv5_BN): QuantizedBatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): QuantizedConv2d(256, 512, kernel_size=(64, 1), stride=(1, 1), scale=0.00016336666885763407, zero_point=62)
  (conv6_BN): QuantizedBatchNorm2d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): QuantizedLinear(in_features=2048, out_features=360, scale=0.00038912895251996815, zero_point=61, qscheme=torch.per_channel_affine)
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
Size (MB):  22.340
Execution time (s): 34.877
==================================================
========================================================================
=== Before Quantization=================================================
ConvModel(
  (conv): Conv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1))
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
Size (MB):  33.556
Execution Time (s): 31.405
=== After Quantization==================================================
ConvModel(
  (conv): QuantizedConv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1), scale=0.021491192281246185, zero_point=65)
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
Size (MB):  8.394
Execution time (s): 3.334
==================================================

Code (borrowed from the torchcrepe repo)

import torch
import torchcrepe
import functools
import torch.nn.functional as F

import os
from torch.quantization import QuantStub, DeQuantStub
import time

class Crepe(torch.nn.Module):
    """Crepe model definition"""

    def __init__(self, model='full'):
        super().__init__()

        # Model-specific layer parameters
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256
        else:
            raise ValueError(f'Model {model} is not supported')

        # Shared layer parameters
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        # Overload with eps and momentum conversion given by MMdnn
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d,
                                          eps=0.001,
                                          momentum=0.0,)

        # Layer definitions
        self.conv1 = torch.nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0])
        self.conv1_BN = batch_norm_fn(
            num_features=out_channels[0])

        self.conv2 = torch.nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1])
        self.conv2_BN = batch_norm_fn(
            num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2])
        self.conv3_BN = batch_norm_fn(
            num_features=out_channels[2])

        self.conv4 = torch.nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=out_channels[3],
            kernel_size=kernel_sizes[3],
            stride=strides[3])
        self.conv4_BN = batch_norm_fn(
            num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(
            in_channels=in_channels[4],
            out_channels=out_channels[4],
            kernel_size=kernel_sizes[4],
            stride=strides[4])
        self.conv5_BN = batch_norm_fn(
            num_features=out_channels[4])

        self.conv6 = torch.nn.Conv2d(
            in_channels=in_channels[5],
            out_channels=out_channels[5],
            kernel_size=kernel_sizes[5],
            stride=strides[5])
        self.conv6_BN = batch_norm_fn(
            num_features=out_channels[5])

        self.classifier = torch.nn.Linear(
            in_features=self.in_features,
            out_features=torchcrepe.PITCH_BINS)

        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x, embed=False):

        x = self.quant(x)
        # Forward pass through first five layers
        x = self.embed(x)

        if embed:
            return x

        # Forward pass through layer six
        x = self.layer(x, self.conv6, self.conv6_BN)

        # shape=(batch, self.in_features)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)

        # Compute logits
        x = torch.sigmoid(self.classifier(x))
        x = self.dequant(x)
        return x

    ###########################################################################
    # Forward pass utilities
    ###########################################################################

    def embed(self, x):
        """Map input audio to pitch embedding"""
        # shape=(batch, 1, 1024, 1)
        x = x[:, None, :, None]

        # Forward pass through first five layers
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        x = self.layer(x, self.conv5, self.conv5_BN)

        return x

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Forward pass through one layer"""
        x = F.pad(x, padding)
        x = conv(x)
        x = F.relu(x)
        x = batch_norm(x)
        return F.max_pool2d(x, (2, 1), (2, 1)).contiguous()

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print(f'Size (MB): {os.path.getsize("temp.p")/1e6 : .3f}')
    os.remove('temp.p')


class ConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1024, 128, kernel_size=(64, 1), stride=(1, 1))
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        return x

# Single-CPU run: pin to 1 thread; comment this out for the multi-CPU (8-thread) run
torch.set_num_threads(1)

def run(model, input):
    print("="  * (23 + 49))
    print("=== Before Quantization" + "=" * 49)

    # Original
    print(model)
    print_size_of_model(model)
    start = time.time()
    model(input)
    print(f"Execution Time (s): {time.time() - start :.3f}")

    # Quantized
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Calibration
    model(input)
    torch.quantization.convert(model, inplace=True)


    print("=== After Quantization" + "=" * 50)
    print(model)
    print_size_of_model(model)
    start = time.time()
    model(input)
    print(f"Execution time (s): {time.time() - start :.3f}")


    print("=" * (23 + 49))

crepe_input = torch.rand(1299, 1024).contiguous()  # 1299 frames of 1024 audio samples each
crepe = Crepe()

conv_model_input = torch.rand([1299, 1024, 128, 1])  # the shape conv2 sees inside Crepe (before padding)
conv_model = ConvModel()

# Set to eval()
crepe.eval()
conv_model.eval()

# Disable running stats for measurement
crepe.conv1_BN.track_running_stats = False
crepe.conv2_BN.track_running_stats = False
crepe.conv3_BN.track_running_stats = False
crepe.conv4_BN.track_running_stats = False
crepe.conv5_BN.track_running_stats = False
crepe.conv6_BN.track_running_stats = False


run(crepe, crepe_input)
run(conv_model, conv_model_input)

Hi @jasonhuh,
This may have something to do with the shapes of the inputs to the conv layers. My guess is that if a shape is not big enough, the work per thread is not significant enough for us to see the wins.
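
One quick way to check would be to time the standalone quantized conv at a couple of input sizes under both thread settings; a rough sketch (batch sizes are illustrative), reusing the ConvModel from the post:

import time
import torch

for threads in (1, 8):
    torch.set_num_threads(threads)
    for batch in (8, 1299):
        # Build and statically quantize a fresh copy of the standalone conv
        m = ConvModel().eval()
        m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        torch.quantization.prepare(m, inplace=True)
        x = torch.rand(batch, 1024, 128, 1)
        m(x)  # calibration pass
        torch.quantization.convert(m, inplace=True)

        start = time.time()
        m(x)
        print(f"threads={threads}, batch={batch}: {time.time() - start:.3f}s")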

@dskhudia do you have any additional insights here about how FBGEMM may be behaving differently for single vs multi-thread here?