Compatibility Issue: Wav2Vec2 QAT with PyTorch 2 Export

I am facing a compatibility issue between the Wav2Vec2 model and the PyTorch 2 Export (PT2E) quantization flow when performing Quantization-Aware Training (QAT). According to the documentation, the model should be exported with torch.export before applying the ai_edge_torch conversion method, but the export step itself fails on my model with the Dynamo error shown at the end of this post.
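
For reference, this is the flow I understood from the PT2E docs, shown as a minimal sketch on a toy module (the tutorial for my PyTorch version uses capture_pre_autograd_graph; I understand newer releases replace it with torch.export.export_for_training):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

# Toy stand-in for the real network, just to illustrate the call order.
toy_model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example_inputs = (torch.randn(2, 16),)

# 1. Capture an ATen-level graph ahead of autograd.
exported = capture_pre_autograd_graph(toy_model, example_inputs)

# 2. Insert fake-quant observers for QAT.
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)

# 3. ...run QAT training here..., then fold observers into quantized ops.
quantized = convert_pt2e(prepared)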

Here is my experimental code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class CustomQuantConfig:
    def __init__(self):
        # Default quantization settings
        backend = 'x86'
        self.static_qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        self.dynamic_qconfig = torch.ao.quantization.default_dynamic_qconfig
        self.no_qconfig = None
        
        # Specific configs for different layer types
        self.conv_qconfig = self.static_qconfig
        self.linear_qconfig = self.static_qconfig
        self.attention_qconfig = self.dynamic_qconfig

    def get_qconfig_mapping(self):
        """
        Returns a QConfigMapping object for the model
        """
        qconfig_mapping = torch.ao.quantization.QConfigMapping()
        
        # Set the global (default) config via the public API
        qconfig_mapping.set_global(self.static_qconfig)
        
        # Set specific configs
        qconfig_mapping.set_module_name_regex(".*feature_extractor.*conv.*", self.conv_qconfig)
        qconfig_mapping.set_module_name_regex(".*feature_projection.*linear.*", self.linear_qconfig)
        qconfig_mapping.set_module_name_regex(".*attention.*", self.attention_qconfig)
        qconfig_mapping.set_module_name_regex(".*layer_norm.*", None)
        qconfig_mapping.set_module_name_regex(".*batch_norm.*", None)
        
        return qconfig_mapping

class QuantWav2Vec2FeatureExtractor(nn.Module):
    def __init__(self, original_extractor: nn.Module):
        super().__init__()
        self.conv_layers = original_extractor.conv_layers
        for param in self.parameters():
            param.requires_grad = False
        self.qconfig = None
    
    def forward(self, input_values):
        # Use unsqueeze instead of None indexing
        hidden_states = input_values.unsqueeze(1)  # Add channel dimension
        for conv_layer in self.conv_layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states

class QuantWav2Vec2Model(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        # Create feature extractor with frozen parameters
        self.feature_extractor = QuantWav2Vec2FeatureExtractor(original_model.wav2vec2.feature_extractor)
        
        # Copy other components
        self.feature_projection = original_model.wav2vec2.feature_projection
        self.encoder = original_model.wav2vec2.encoder
        self.dropout = original_model.dropout
        self.lm_head = original_model.lm_head
        self.config = original_model.config
        
        # Initialize quantization configuration
        self.qconfig_handler = CustomQuantConfig()
        self.qconfig_mapping = self.qconfig_handler.get_qconfig_mapping()
        
        # Apply configuration
        self._configure_quantization()
    
    def _configure_quantization(self):
        """Configure quantization for different parts of the model"""
        # Feature extractor is frozen and not quantized
        self.feature_extractor.qconfig = None
        
        # Configure feature projection
        for name, module in self.feature_projection.named_modules():
            if isinstance(module, nn.Linear):
                module.qconfig = self.qconfig_handler.linear_qconfig
            elif isinstance(module, nn.LayerNorm):
                module.qconfig = None
        
        # Configure encoder
        for name, module in self.encoder.named_modules():
            if "pos_conv_embed" in name:
                module.qconfig = None
            elif isinstance(module, nn.Linear):
                if "attention" in name:
                    module.qconfig = self.qconfig_handler.attention_qconfig
                else:
                    module.qconfig = self.qconfig_handler.linear_qconfig
            elif isinstance(module, nn.LayerNorm):
                module.qconfig = None
        
        # Configure lm_head
        self.lm_head.qconfig = self.qconfig_handler.dynamic_qconfig
    
    def forward(self, input_values, attention_mask=None, labels=None):
        # Extract features
        hidden_states = self.feature_extractor(input_values)
        hidden_states = hidden_states.transpose(1, 2)
        
        # Project features
        hidden_states = self.feature_projection(hidden_states)
        
        # Create attention mask if needed
        if attention_mask is None:
            attention_mask = torch.ones_like(input_values[:, :hidden_states.shape[1]])
        
        # Extended attention mask for transformer
        extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0
        
        # Apply encoder
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
        )
        
        if isinstance(encoder_outputs, tuple):
            hidden_states = encoder_outputs[0]
        else:
            hidden_states = encoder_outputs
            
        hidden_states = self.dropout(hidden_states)
        
        # Final projection
        logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            input_lengths = torch.full(
                size=(input_values.shape[0],),
                fill_value=hidden_states.shape[1],
                dtype=torch.long,
                device=input_values.device
            )
            label_lengths = torch.sum(labels >= 0, dim=-1)
            
            log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.ctc_loss(
                log_probs.transpose(0, 1),
                labels,
                input_lengths,
                label_lengths,
                blank=self.config.pad_token_id,
                reduction='mean',
                zero_infinity=True
            )
        
        return (loss, logits) if loss is not None else logits

def prepare_model_for_quantization(model, example_inputs):
    """
    Prepare model for quantization with proper configurations
    """
    # Note: the PT2E quantizer does not consume the FX-style QConfigMapping
    # built on the model; X86InductorQuantizer.set_global expects a
    # QuantizationConfig instead.
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))
    
    # Make sure input is a tensor, not a tuple
    if isinstance(example_inputs, tuple):
        example_inputs = example_inputs[0]
    
    # Export the model with static shapes; torch._dynamo.export returns a
    # (graph_module, guards) tuple, so keep only the graph module.
    exported_model = torch._dynamo.export(
        model,
        example_inputs,
        aten_graph=True,
        tracing_mode="real",
        assume_static_by_default=True
    )[0]
    
    # Prepare for QAT
    prepared_model = prepare_qat_pt2e(exported_model, quantizer)
    return prepared_model

def train_quantized_model(model, train_dataloader, num_epochs=3, learning_rate=1e-4):
    """
    Train the quantization-aware model
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Exported graph modules don't support .train()/.eval() directly;
    # the PT2E flow provides move_exported_model_to_train for this.
    model = torch.ao.quantization.move_exported_model_to_train(model)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(train_dataloader):
            input_values = batch["input_values"].to(torch.float32)
            labels = batch["labels"]
            
            optimizer.zero_grad()
            loss, _ = model(input_values=input_values, labels=labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    
    return model

def convert_and_save_model(model, save_path):
    """
    Convert the trained QAT model and save it
    """
    # Fold dropout and similar ops into eval behavior before conversion; the
    # PT2E docs recommend move_exported_model_to_eval over calling .eval().
    model = torch.ao.quantization.move_exported_model_to_eval(model)
    converted_model = convert_pt2e(model)
    torch.save(converted_model.state_dict(), save_path)
    return converted_model
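
For context, this is roughly how I drive these functions end to end; the checkpoint name, the dataloader, and the 1-second example input are placeholders for my actual setup:

from transformers import Wav2Vec2ForCTC

# Placeholder checkpoint; I use my own fine-tuned model in practice.
original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model = QuantWav2Vec2Model(original)

# One batch of raw 16 kHz audio as the example input for export.
example_inputs = (torch.randn(1, 16000),)

prepared = prepare_model_for_quantization(model, example_inputs)  # <-- fails here
prepared = train_quantized_model(prepared, train_dataloader)
convert_and_save_model(prepared, "wav2vec2_qat_int8.pt")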

And this is the error I am facing:

File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 343, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method GetAttrVariable(TupleVariable(length=2), shape) __getitem__ (ConstantVariable(),) {}

from user code:
   File "/home/jupyter/Wav2vec2_qat/wav2vec2_with_ai_edge_torch.py", line 111, in forward
    attention_mask = torch.ones_like(input_values[:, :hidden_states.shape[1]])

@andrewor, can you take a look at this question, please?