I’m trying to apply static quantization to a model that uses an nn.TransformerEncoderLayer.
But when running the model, I get the following error:
File "/envs/transfo/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 556, in <genexpr>
elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
AttributeError: 'function' object has no attribute 'is_cuda'
The model is very basic: embeddings followed by a TransformerEncoderLayer, followed by a linear layer. But I can’t make it work…
Here is a Colab notebook reproducing the issue: Google Colab
Here is the script reproducing the issue:
import torch
from torch import nn
from torch.ao.quantization import qconfig
class Quantformer(nn.Module):
    """Causally-masked Transformer language model set up for eager-mode static quantization.

    The pipeline is: token embedding + learned positional embedding (both
    float) -> ``QuantStub`` -> one ``nn.TransformerEncoderLayer`` with a
    square subsequent (causal) mask -> linear projection to the vocabulary
    -> ``DeQuantStub``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        max_seq_len: int,
        nhead: int = 8,
    ) -> None:
        """Build the model.

        Args:
            vocab_size: Number of rows in the token embedding and width of
                the output logits.
            embedding_dim: Embedding width; also used as ``d_model`` of the
                encoder layer.
            hidden_dim: ``dim_feedforward`` of the encoder layer.
            dropout_rate: Dropout probability used inside the encoder layer
                and by the module-level dropout.
            max_seq_len: Number of rows in the positional embedding, i.e.
                the longest sequence ``forward`` accepts.
            nhead: Number of attention heads of the encoder layer
                (defaults to 8, the value that used to be hard-coded).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len

        # Marks where activations switch from float to quantized once the
        # model has been converted.
        self.quant = torch.ao.quantization.QuantStub()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embedding_dim)
        self.transformer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        # Marks where activations switch back from quantized to float.
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """Compute language-model logits for a batch of token ids.

        Args:
            src: Integer token ids; dimension 0 is the batch, dimension 1
                the sequence.

        Returns:
            The de-quantized output of the final linear layer, one
            ``vocab_size``-wide row per input position.

        Raises:
            IndexError: If the sequence is longer than ``max_seq_len``
                (the positional embedding has no row for the extra
                positions).
        """
        seq_len = src.size(1)
        if seq_len > self.max_seq_len:
            # Same exception type the positional embedding lookup would
            # raise, but with an actionable message.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # A single (1, seq_len) row of positions is enough: the positional
        # embeddings broadcast over the batch dimension in the sum below.
        pos_ids = torch.arange(seq_len, dtype=src.dtype, device=src.device).unsqueeze(0)
        embeds = self.dropout(self.embedding(src)) + self.pos_embedding(pos_ids)
        embeds = self.quant(embeds)

        mask = nn.Transformer.generate_square_subsequent_mask(embeds.size(1), device=embeds.device)
        out = self.transformer(embeds, src_mask=mask)

        lm_logits = self.dropout(self.fc(out))
        return self.dequant(lm_logits)
# Reproduction script: eager-mode post-training static quantization of
# Quantformer on CPU (prepare -> calibrate -> convert -> run).
device = torch.device("cpu")

# The qconfig chosen below targets the "qnnpack" backend, so select the
# matching quantized engine when this PyTorch build provides it; the qconfig
# and the engine used for the quantized kernels are expected to agree.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"

sq_model = Quantformer(
    vocab_size=10000,
    embedding_dim=128,
    hidden_dim=512,
    dropout_rate=0,
    max_seq_len=10,
).to(device)
sq_model.eval()

# Model-wide qconfig, overridden with a weight-only qconfig for the two
# embedding tables.
sq_model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
sq_model.embedding.qconfig = qconfig.float_qparams_weight_only_qconfig
sq_model.pos_embedding.qconfig = qconfig.float_qparams_weight_only_qconfig

# Insert observers, then run one batch through the prepared model so they
# can record statistics.
sq_model_prepared = torch.ao.quantization.prepare(sq_model)
x = torch.randint(3, 10000, (1, 10))
sq_model_prepared(x)

# Swap observed modules for their quantized counterparts and run the result.
# NOTE(review): presumably this last call is where the reported
# "'function' object has no attribute 'is_cuda'" error is raised — confirm
# against the full traceback.
squant_model = torch.ao.quantization.convert(sq_model_prepared)
yy = squant_model(x)