PyTorch onnx.dynamo_export problem with transformer square mask

Hi, I am trying to export my model to ONNX using torch.onnx.dynamo_export. It's a transformer-based Seq2Seq model. Exporting the encoder works fine, but when I attempt to export the decoder, dynamo_export crashes with:
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: It appears that you’re trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is Eq(i0, 1) (unhinted: Eq(i0, 1)).

I investigated further and found that the issue lies with the decoder's square mask: if I invoke the export without the mask, there is no crash. Here is a minimal reproduction:

import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(self, nhead=16, num_encoder_layers=12):
        super(TransformerDecoder, self).__init__()
        self.transformer = nn.Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers)

    def forward(self, trg: torch.Tensor, memory: torch.Tensor):
        trg_seq_length, N, E = trg.size()
        # Building the square subsequent mask inside forward is what triggers the crash.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length)
        out = self.transformer.decoder(trg, memory, tgt_mask=trg_mask)
        #out = self.transformer.decoder(trg, memory)  # this variant exports fine
        return out

transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)

decoder = TransformerDecoder(nhead=16, num_encoder_layers=12)
mem = transformer_model.encoder(src)
out = decoder(tgt, mem)

onnx_modelD = torch.onnx.dynamo_export(decoder, tgt, mem)

The above is the crashing version; if I call transformer.decoder without tgt_mask (the commented-out line), the export succeeds.
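The only workaround I can think of is to build the mask eagerly outside the module and pass it in as an extra input, so the export never traces the shape-dependent mask construction. This is an untested sketch (the wrapper name MaskedDecoder is just for illustration), and it fixes the target length at export time:

import torch
import torch.nn as nn

class MaskedDecoder(nn.Module):
    def __init__(self, nhead=16, num_encoder_layers=12):
        super().__init__()
        self.transformer = nn.Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers)

    def forward(self, trg: torch.Tensor, memory: torch.Tensor, trg_mask: torch.Tensor):
        # The mask is now a plain tensor input, so no data-dependent
        # shape logic runs inside forward during tracing.
        return self.transformer.decoder(trg, memory, tgt_mask=trg_mask)

decoder = MaskedDecoder(nhead=16, num_encoder_layers=12)
tgt = torch.rand((20, 32, 512))
mem = torch.rand((10, 32, 512))

# Causal mask built eagerly: -inf above the diagonal, 0 elsewhere.
# This should match what generate_square_subsequent_mask produces.
trg_mask = torch.triu(torch.full((tgt.size(0), tgt.size(0)), float('-inf')), diagonal=1)

onnx_modelD = torch.onnx.dynamo_export(decoder, tgt, mem, trg_mask)

But this loses the dynamic sequence length, which is the whole point of generating the mask in forward, so I would prefer a proper fix.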

Am I doing something wrong?

Thanks, Jozef