I am facing a compatibility issue between the Wav2Vec2 model and the PyTorch 2 Export library when performing quantization-aware training (QAT). According to the documentation, torch.export should be used before applying the ai_edge_torch conversion method. However, I cannot get my model through the export step; the traceback at the end of this post shows where it fails.
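For context, this is the end-to-end flow I am trying to follow, based on my reading of the ai_edge_torch examples. The checkpoint name and the one-second 16 kHz input below are just placeholders, and whether the model actually survives torch.export is exactly what I am struggling with:

import torch
import ai_edge_torch
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()
sample_inputs = (torch.randn(1, 16000),)  # batch of one-second 16 kHz waveforms

# The model has to be traceable by torch.export first ...
exported_program = torch.export.export(model, sample_inputs)

# ... and only then handed to the ai_edge_torch conversion.
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export("wav2vec2.tflite")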
Here is my experimental code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

class CustomQuantConfig:
    def __init__(self):
        # Default quantization settings
        backend = 'x86'
        self.static_qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        self.dynamic_qconfig = torch.ao.quantization.default_dynamic_qconfig
        self.no_qconfig = None
        # Specific configs for different layer types
        self.conv_qconfig = self.static_qconfig
        self.linear_qconfig = self.static_qconfig
        self.attention_qconfig = self.dynamic_qconfig

    def get_qconfig_mapping(self):
        """Return a QConfigMapping object for the model."""
        qconfig_mapping = torch.ao.quantization.QConfigMapping()
        # Set global config
        qconfig_mapping.set_global(self.static_qconfig)
        # Set specific configs
        qconfig_mapping.set_module_name_regex(".*feature_extractor.*conv.*", self.conv_qconfig)
        qconfig_mapping.set_module_name_regex(".*feature_projection.*linear.*", self.linear_qconfig)
        qconfig_mapping.set_module_name_regex(".*attention.*", self.attention_qconfig)
        qconfig_mapping.set_module_name_regex(".*layer_norm.*", None)
        qconfig_mapping.set_module_name_regex(".*batch_norm.*", None)
        return qconfig_mapping
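# NOTE: as far as I understand the docs, a QConfigMapping like the one above belongs
# to the FX graph-mode workflow (prepare_qat_fx), whereas the PT2E export path used
# further down expects a Quantizer object -- this mismatch may be part of my problem.
# A minimal sketch of the FX-style usage I had in mind (`float_model` is a placeholder
# for the wrapped float model):
#
#   from torch.ao.quantization.quantize_fx import prepare_qat_fx
#   qconfig_mapping = CustomQuantConfig().get_qconfig_mapping()
#   example_inputs = (torch.randn(1, 16000),)
#   prepared_fx = prepare_qat_fx(float_model.train(), qconfig_mapping, example_inputs)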
class QuantWav2Vec2FeatureExtractor(nn.Module):
    def __init__(self, original_extractor: nn.Module):
        super().__init__()
        self.conv_layers = original_extractor.conv_layers
        for param in self.parameters():
            param.requires_grad = False
        self.qconfig = None

    def forward(self, input_values):
        # Use unsqueeze instead of None indexing
        hidden_states = input_values.unsqueeze(1)  # Add channel dimension
        for conv_layer in self.conv_layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
class QuantWav2Vec2Model(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        # Create feature extractor with frozen parameters
        self.feature_extractor = QuantWav2Vec2FeatureExtractor(original_model.wav2vec2.feature_extractor)
        # Copy other components
        self.feature_projection = original_model.wav2vec2.feature_projection
        self.encoder = original_model.wav2vec2.encoder
        self.dropout = original_model.dropout
        self.lm_head = original_model.lm_head
        self.config = original_model.config
        # Initialize quantization configuration
        self.qconfig_handler = CustomQuantConfig()
        self.qconfig_mapping = self.qconfig_handler.get_qconfig_mapping()
        # Apply configuration
        self._configure_quantization()

    def _configure_quantization(self):
        """Configure quantization for different parts of the model."""
        # Feature extractor is frozen and not quantized
        self.feature_extractor.qconfig = None
        # Configure feature projection
        for name, module in self.feature_projection.named_modules():
            if isinstance(module, nn.Linear):
                module.qconfig = self.qconfig_handler.linear_qconfig
            elif isinstance(module, nn.LayerNorm):
                module.qconfig = None
        # Configure encoder
        for name, module in self.encoder.named_modules():
            if "pos_conv_embed" in name:
                module.qconfig = None
            elif isinstance(module, nn.Linear):
                if "attention" in name:
                    module.qconfig = self.qconfig_handler.attention_qconfig
                else:
                    module.qconfig = self.qconfig_handler.linear_qconfig
            elif isinstance(module, nn.LayerNorm):
                module.qconfig = None
        # Configure lm_head
        self.lm_head.qconfig = self.qconfig_handler.dynamic_qconfig

    def forward(self, input_values, attention_mask=None, labels=None):
        # Extract features
        hidden_states = self.feature_extractor(input_values)
        hidden_states = hidden_states.transpose(1, 2)
        # Project features
        hidden_states = self.feature_projection(hidden_states)
        # Create attention mask if needed
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values[:, :hidden_states.shape[1]])
        # Extended attention mask for transformer
        extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0
        # Apply encoder
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
        )
        if isinstance(encoder_outputs, tuple):
            hidden_states = encoder_outputs[0]
        else:
            hidden_states = encoder_outputs
        hidden_states = self.dropout(hidden_states)
        # Final projection
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            input_lengths = torch.full(
                size=(input_values.shape[0],),
                fill_value=hidden_states.shape[1],
                dtype=torch.long,
                device=input_values.device
            )
            label_lengths = torch.sum(labels >= 0, dim=-1)
            log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.ctc_loss(
                log_probs.transpose(0, 1),
                labels,
                input_lengths,
                label_lengths,
                blank=self.config.pad_token_id,
                reduction='mean',
                zero_infinity=True
            )
        return (loss, logits) if loss is not None else logits
def prepare_model_for_quantization(model, example_inputs):
    """Prepare the model for quantization with the proper configurations."""
    # Get quantization configuration
    qconfig_mapping = model.qconfig_mapping
    # Create quantizer with the QConfig mapping
    # (I am not sure X86InductorQuantizer.set_global is meant to take a
    # QConfigMapping -- this may be part of the compatibility problem.)
    quantizer = X86InductorQuantizer()
    quantizer.set_global(qconfig_mapping)
    # Make sure input is a tensor, not a tuple
    if isinstance(example_inputs, tuple):
        example_inputs = example_inputs[0]
    # Export model with static shapes
    exported_model = torch._dynamo.export(
        model,
        example_inputs,
        aten_graph=True,
        tracing_mode="real",
        assume_static_by_default=True
    )[0]
    # Prepare for QAT
    prepared_model = prepare_qat_pt2e(exported_model, quantizer)
    return prepared_model
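# For comparison, my reading of the PT2E QAT recipe from the PyTorch docs: the
# X86InductorQuantizer is configured with a QuantizationConfig rather than a
# QConfigMapping, and the model is captured with torch.export.export_for_training
# (available in torch >= 2.5). The helper name below is mine, not from any library.
def prepare_model_pt2e_reference(model, example_inputs):
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
        get_default_x86_inductor_quantization_config,
    )
    # Capture the model for training, then attach the default x86 QAT config
    exported = torch.export.export_for_training(model, example_inputs).module()
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))
    return prepare_qat_pt2e(exported, quantizer)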
def train_quantized_model(model, train_dataloader, num_epochs=3, learning_rate=1e-4):
    """Train the quantization-aware model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(train_dataloader):
            input_values = batch["input_values"].to(torch.float32)
            labels = batch["labels"]
            optimizer.zero_grad()
            loss, _ = model(input_values=input_values, labels=labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    return model
def convert_and_save_model(model, save_path):
    """Convert the trained QAT model and save it."""
    converted_model = convert_pt2e(model)
    torch.save(converted_model.state_dict(), save_path)
    return converted_model
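For completeness, this is roughly how I drive these functions; the checkpoint name and train_dataloader are placeholders, and the failure reported below happens inside prepare_model_for_quantization, during the export step:

from transformers import Wav2Vec2ForCTC

original_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model = QuantWav2Vec2Model(original_model)

example_inputs = torch.randn(1, 16000)  # one second of 16 kHz audio
prepared_model = prepare_model_for_quantization(model, example_inputs)  # <- fails here
trained_model = train_quantized_model(prepared_model, train_dataloader)
quantized_model = convert_and_save_model(trained_model, "wav2vec2_qat_int8.pt")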
And this is the issue I am facing:
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 343, in call_method
unimplemented(f"call_method {self} {name} {args} {kwargs}")
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method GetAttrVariable(TupleVariable(length=2), shape) __getitem__ (ConstantVariable(),) {}
from user code:
File "/home/jupyter/Wav2vec2_qat/wav2vec2_with_ai_edge_torch.py", line 111, in forward
attention_mask = torch.ones_like(input_values[:, :hidden_states.shape[1]])