Is it planned to support nn.Embedding quantization?

First of all, I would like to thank you for the awesome torch.quantization. But at the moment, quantization of embeddings is not supported, although it is usually one of the biggest parts of the model (in terms of size) in NLP.
I tried to use nn.Embedding the way nn.Linear is used, since they are very similar in nature, but got the following error:

RuntimeError: Could not run 'aten::index_select' with arguments from the 'QuantizedCPUTensorId' backend. 'aten::index_select' is only available for these backends: [CPUTensorId, CUDATensorId, SparseCPUTensorId, SparseCUDATensorId, VariableTensorId].
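
For reference, a minimal way to hit the same backend error is to call index_select on a quantized weight tensor (a hypothetical sketch, not my actual model code):

import torch

# Quantize an embedding-like weight matrix, then try to look up rows by index.
weight = torch.randn(10, 12)
q_weight = torch.quantize_per_tensor(weight, scale=0.1, zero_point=0, dtype=torch.qint8)
indices = torch.tensor([1, 3, 4])

# nn.Embedding's lookup boils down to an index_select on the weight; on versions
# without a quantized index_select kernel this fails with an error like the one above.
q_weight.index_select(0, indices)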

So I’m interested: is nn.Embedding quantization planned?

I think @Zafar is working on this


@jerryzh168
Hey, I saw that Embedding quantization was added in 1.7.0, but I can’t reproduce it with the latest version. I tried both static and dynamic quantization. Can you please share a code snippet that converts Embeddings to int8?

When I try static quantization, I’m getting:
AssertionError: The only supported dtype for nnq.Embedding is torch.quint8

I am having the same error. In version 1.6, embeddings were not a problem, or they were simply ignored.

@skurzhanskyi and @pintonos, could you share a small repro of the error?

A small repro would be great. In the meantime, could you try setting the qconfig for the embedding module to float_qparams_weight_only_qconfig? We only support float-qparams quantization for Embedding layers; if you use the default qconfig for them, you may run into this error.

A small example to demonstrate this:

import torch
from torch.quantization import float_qparams_weight_only_qconfig, default_qconfig


class EmbeddingWithLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(5, 5)
        # Embedding layers only support float-qparams (weight-only) quantization
        self.emb.qconfig = float_qparams_weight_only_qconfig
        # the rest of the model (the Linear) uses the default static qconfig
        self.qconfig = default_qconfig

    def forward(self, indices, linear_in):
        return self.emb(indices), self.fc(linear_in)

Thanks for the example!

However, according to this file, float_qparams_weight_only_qconfig should be part of torch.quantization. With the PyTorch 1.7.1 CPU version, torch.quantization.float_qparams_weight_only_qconfig cannot be imported!

Is this configuration not published yet?

Looks like this was renamed after 1.7. If you switch to nightly then you should be able to use it.

If you’re using the 1.7 release, try using float_qparams_dynamic_qconfig instead.
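
For example, on 1.7.x the qconfig assignment would look roughly like this (a sketch; the exact import path is an assumption based on the rename mentioned above):

import torch
from torch.quantization import float_qparams_dynamic_qconfig  # renamed to float_qparams_weight_only_qconfig in later releases

emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
emb.qconfig = float_qparams_dynamic_qconfig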

I updated to the nightly version. Using this code for testing:

import torch
import numpy as np
from torch.quantization import QuantStub, DeQuantStub, float_qparams_weight_only_qconfig, default_qconfig


class EmbeddingWithLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(5, 5)
        self.emb.qconfig = float_qparams_weight_only_qconfig
        self.qconfig = default_qconfig

    def forward(self, indices, linear_in):
        return self.emb(indices), self.fc(linear_in)


# create a model instance
model_fp32 = EmbeddingWithLinear()

indices_fp32 = torch.tensor(np.array([1, 3, 4, 5])).long()
input_fp32 = torch.randn(5, 5)

model_fp32.eval()
res = model_fp32(indices_fp32, input_fp32)
print(res)

# insert observers for calibration
model_fp32_prepared = torch.quantization.prepare(model_fp32)

# calibration pass with representative inputs
model_fp32_prepared(indices_fp32, input_fp32)

# convert the calibrated model to a quantized one
model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(indices_fp32, input_fp32)
print(res)

Getting this error:

RuntimeError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. ...

Isn’t that exactly your example code?

Hi @pintonos,

The error occurs because you are passing an FP32 input tensor to a quantized operator. If you change the model to include quant/dequant stubs, it should work as expected:

class EmbeddingWithLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(5, 5)
        self.emb.qconfig = float_qparams_weight_only_qconfig
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, indices, linear_in):
        # the embedding consumes int64 indices directly; no quant stub needed here
        a = self.emb(indices)
        # quantize the FP32 input before the quantized Linear, dequantize its output
        x = self.quant(linear_in)
        quant = self.fc(x)
        return a, self.dequant(quant)
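
With the stubs in place, the prepare/calibrate/convert flow from your earlier script should then run end to end (a sketch reusing indices_fp32 and input_fp32 from that script):

model_fp32 = EmbeddingWithLinear().eval()
model_fp32_prepared = torch.quantization.prepare(model_fp32)
model_fp32_prepared(indices_fp32, input_fp32)   # calibration pass
model_int8 = torch.quantization.convert(model_fp32_prepared)
print(model_int8(indices_fp32, input_fp32))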

Thanks so far!

Seems to work now, but I am getting an error when slicing my indices tensor after the model has been calibrated and quantized. The slicing worked fine before quantization.

emb(Xi[:, i - self.num, :])

Error:

RuntimeError: Expect weight, indices, and offsets to be contiguous.

Using torch.LongTensor(128, 1).random_(0, 10), which produces the same tensor shape, works as input, but the sliced tensor seems to cause problems.

Any suggestions?

It seems like this was recently modified in https://github.com/pytorch/pytorch/pull/48993. The operator expects the values passed in to the embedding operator to be contiguous.
You could check the inputs by doing x.is_contiguous() and call x.contiguous() if they are not.
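
Applied to your snippet, the check would look roughly like this (a sketch reusing Xi, i, self.num, and emb from your post):

indices = Xi[:, i - self.num, :]       # slicing can return a non-contiguous view
if not indices.is_contiguous():
    indices = indices.contiguous()     # copy into contiguous memory
out = emb(indices)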

I’ll file an issue to support this in the operator itself.


Sorry to reopen this topic.

Is it somehow possible to skip quantizing embedding layers in post-training static quantization, so that only the linear layers, for instance, get quantized, as in earlier versions?

It is possible to set the qconfig of the Embedding to None if you wish to skip quantizing it. For example:

class EmbeddingWithLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
        self.fc = torch.nn.Linear(5, 5)
        # qconfig = None: this submodule is skipped during quantization
        self.emb.qconfig = None
        self.qconfig = default_qconfig

    def forward(self, indices, linear_in):
        return self.emb(indices), self.fc(linear_in)
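If you'd rather not modify the model class itself, the same effect can be achieved by clearing the qconfig from outside before prepare (a sketch; TinyModel is a hypothetical model that does not set any per-module qconfig of its own):

import torch
from torch.quantization import default_qconfig

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 12)
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, indices, x):
        return self.emb(indices), self.fc(x)

model = TinyModel().eval()
model.qconfig = default_qconfig   # quantize the rest of the model
model.emb.qconfig = None          # but skip the embedding
prepared = torch.quantization.prepare(model)
prepared(torch.tensor([1, 2, 3]), torch.randn(5, 5))   # calibration pass
model_int8 = torch.quantization.convert(prepared)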

Hi,
I have the same error. I'm trying to quantize DETR.

Error:

“AssertionError: The only supported dtype for nnq.Embedding is torch.quint8”.

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_embed.qconfig = None  # --------------------------------------------------
        self.qconfig = default_qconfig
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        samples = self.quant(samples)  # ------------------------------------------------
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        out = self.dequant(out)  # -------------------------------------------------------------
        return out

Hi @supriyar,
I tried this approach and got the same error as before. I'm trying to quantize the DETR model but ran into this error. Do you have any suggestions on what else I should try?