How do I prevent torch.onnx.export from merging torch.reshape calls that use a constant target shape into a single Reshape?

```python
if attention_mask is not None:
    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
        )
    # One reshape per head: every call has the same input and the same constant target shape.
    separate_mask = [torch.reshape(attention_mask, [-1, 1, 256, 4096]) for _ in range(self.num_heads)]
    attn_weights = [aw + mask for aw, mask in zip(attn_weights, separate_mask)]
```

When I export the model with torch.onnx.export, the reshape ops from all decoder layers are merged into a single Reshape node in the ONNX graph. This is not the behavior I want: I need each reshape to remain a separate node.
For other reasons, the target shape of the reshape cannot be a variable shape computed at run time; it has to be a constant.
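To make the constant/variable distinction concrete, here is a small sketch (the function names and the shrunk 4×8 sizes are hypothetical, chosen for illustration). A constant target shape is a literal list, as in the code above; a "variable" one is computed from the tensor's own sizes. In eager mode both give the same result; the difference only matters when tracing for export, where the size() calls are typically recorded so that the exported graph may compute the shape at run time instead of embedding a literal.

```python
import torch


def reshape_constant(mask: torch.Tensor) -> torch.Tensor:
    # Constant target shape: a literal list, as in the question
    # (shrunk here from [-1, 1, 256, 4096] to [-1, 1, 4, 8]).
    return torch.reshape(mask, [-1, 1, 4, 8])


def reshape_from_sizes(mask: torch.Tensor) -> torch.Tensor:
    # "Variable" target shape: derived from the input's own sizes. When the
    # model is traced for export, these size() calls are typically recorded,
    # so the exported graph may compute the shape at run time instead of
    # embedding a literal.
    q_len, kv_len = mask.size(-2), mask.size(-1)
    return torch.reshape(mask, [-1, 1, q_len, kv_len])


mask = torch.zeros(2, 1, 4, 8)
a = reshape_constant(mask)
b = reshape_from_sizes(mask)
# In eager mode both variants produce the same tensor.
print(a.shape, torch.equal(a, b))
```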
I tried passing do_constant_folding=False to torch.onnx.export, but the reshapes were still merged.

Version information

  • Python 3.10
  • PyTorch 2.1.2
  • onnx 1.14.1