Can't get dynamic shape with torch.export.export_for_training

Hi, I want to implement quantization-aware training (QAT) for YOLOv5, but I can't get a dynamic input height and width when exporting the model with the torch.export.export_for_training function.

Here is my reproducible code; only part of YOLOv5 is included.

# Ultralytics YOLOv5 🚀, AGPL-3.0 license
"""
YOLO-specific modules.

Usage:
    $ python models/yolo.py --cfg yolov5s.yaml
"""

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
  prepare_qat_pt2e,
  convert_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
  XNNPACKQuantizer,
  get_symmetric_quantization_config,
)
from torch.export import Dim


import torch
import torch.nn as nn


def autopad(k, p=None, d=1):
    """
    Computes the padding that gives a 'same'-shaped output, taking dilation into account.

    `k`: kernel size (int or sequence of ints), `p`: explicit padding (returned unchanged when given),
    `d`: dilation factor.
    """
    scalar_kernel = isinstance(k, int)
    if d > 1:
        # Dilation widens the effective kernel footprint.
        if scalar_kernel:
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is not None:
        return p
    # No explicit padding supplied: use half of the (effective) kernel size.
    if scalar_kernel:
        return k // 2
    return [size // 2 for size in k]

class C3(nn.Module):
    """CSP Bottleneck block built from three convolutions and a stack of `n` Bottleneck layers."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Builds the block from `c1` input channels, `c2` output channels, `n` bottleneck repeats, the `shortcut`
        flag, `g` convolution groups, and the hidden-channel ratio `e`.
        """
        super().__init__()
        hidden = int(c2 * e)  # channel width of both internal branches
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)  # consumes the two concatenated branches
        bottlenecks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        """Sends `x` through the bottleneck branch and the bypass branch, concatenates both on dim 1, then applies
        the final convolution.
        """
        deep = self.m(self.cv1(x))
        bypass = self.cv2(x)
        return self.cv3(torch.cat((deep, bypass), 1))

class Conv(nn.Module):
    """Convolution block: a bias-free Conv2d followed by BatchNorm2d and an activation."""

    default_act = nn.SiLU()  # activation used when `act` is True

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Creates the conv/bn/activation layers; `act` may be True (default activation), an nn.Module (used as
        is), or anything else (identity).
        """
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Returns activation(batch_norm(conv(x))) for the input tensor `x`."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def forward_fuse(self, x):
        """Returns activation(conv(x)) for the input tensor `x`, skipping the batch-norm layer."""
        y = self.conv(x)
        return self.act(y)

class Bottleneck(nn.Module):
    """Bottleneck layer: a 1x1 Conv followed by a 3x3 Conv, with an optional residual shortcut."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        """Sets up the two convolutions with `int(c2 * e)` hidden channels; the shortcut is only enabled when
        `shortcut` is truthy and `c1 == c2`.
        """
        super().__init__()
        hidden = int(c2 * e)  # channels between the two convolutions
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Applies both convolutions to the tensor `x` and adds `x` back to the result when the shortcut is
        enabled.
        """
        y = self.cv2(self.cv1(x))
        if self.add:
            return x + y
        return y


class YOLOv5(nn.Module):
    """Reduced model used for the reproduction: two strided Conv blocks followed by one C3 block."""

    def __init__(self):
        """Creates Conv1 (3->16, k=6, s=2, p=2), Conv2 (16->32, k=3, s=2) and C3_1 (32->32, n=1)."""
        super().__init__()
        self.Conv1 = Conv(3, 16, 6, 2, 2)
        self.Conv2 = Conv(16, 32, 3, 2)
        self.C3_1 = C3(32, 32, 1)

    def forward(self, x):
        """Applies Conv1, Conv2 and C3_1 in sequence to `x` and returns the result."""
        return self.C3_1(self.Conv2(self.Conv1(x)))
    
# Build the reduced model defined above.
model = YOLOv5()

# Example input used for export: a random tensor of shape (4, 3, 1024, 1024).
example_inputs = (torch.rand(4, 3, 1024, 1024),)

# Mark dims 2 and 3 of the forward() argument "x" as dynamic.
# The outer parentheses do not create a tuple, so this is a plain dict keyed by argument name.
# Both Dim objects are created without explicit min/max bounds.
dynamic_shapes = (
  {"x":{ 
        2: torch.export.Dim("dim3"),
        3: torch.export.Dim("dim4")}}
)

# This is the call that raises the ConstraintViolationError quoted below.
exported_model = torch.export.export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes).module()

The error is as follows:

E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0] Error while creating guard:
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0] Name: ''
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Source: shape_env
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Create Function: SHAPE_ENV
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Guard Types: None
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Code List: None
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Object Weakref: None
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     Guarded Class Weakref: None
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0] Traceback (most recent call last):
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_guards.py", line 281, in create
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     return self.create_fn(builder, self)
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1836, in SHAPE_ENV
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     guards = output_graph.shape_env.produce_guards(
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4178, in produce_guards
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]     raise ConstraintViolationError(
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (dim3, dim4)! For more information, run with TORCH_LOGS="+dynamic".
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((L['x'].size()[3]//2), 1).
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((L['x'].size()[2]//2), 1).
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((((L['x'].size()[3]//2) - 1)//2) + 1, 1).
E1119 02:58:42.538000 56708 site-packages/torch/_guards.py:283] [0/0]   - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((((L['x'].size()[2]//2) - 1)//2) + 1, 1).
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0] Created at:
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 615, in transform
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]     tracer = InstructionTranslator(
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in __init__
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]     output=OutputGraph(
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 317, in __init__
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]     self.init_ambient_guards()
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 463, in init_ambient_guards
E1119 02:58:42.539000 56708 site-packages/torch/_guards.py:285] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 560, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1477, in inner
    raise constraint_violation_error
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 796, in _compile_inner
    check_fn = CheckFunctionManager(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 2261, in __init__
    guard.create(builder)
  File "/opt/conda/lib/python3.10/site-packages/torch/_guards.py", line 281, in create
    return self.create_fn(builder, self)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1836, in SHAPE_ENV
    guards = output_graph.shape_env.produce_guards(
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4178, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (dim3, dim4)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((L['x'].size()[3]//2), 1).
  - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((L['x'].size()[2]//2), 1).
  - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((((L['x'].size()[3]//2) - 1)//2) + 1, 1).
  - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((((L['x'].size()[2]//2) - 1)//2) + 1, 1).


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/yolov5/models/./reproced_test_yolo.py", line 120, in <module>
    exported_model = torch.export.export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes).module()
  File "/opt/conda/lib/python3.10/site-packages/torch/export/__init__.py", line 154, in export_for_training
    return _export_for_training(
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 1017, in wrapper
    raise e
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 990, in wrapper
    ep = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 1746, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 1252, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/opt/conda/lib/python3.10/site-packages/torch/export/_trace.py", line 576, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (dim3, dim4)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((L['x'].size()[3]//2), 1).
  - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((L['x'].size()[2]//2), 1).
  - Not all values of dim4 = L['x'].size()[3] in the specified range satisfy the generated guard Ne((((L['x'].size()[3]//2) - 1)//2) + 1, 1).
  - Not all values of dim3 = L['x'].size()[2] in the specified range satisfy the generated guard Ne((((L['x'].size()[2]//2) - 1)//2) + 1, 1).

In YOLOv5 training, the height and width of the input are dynamic, so I want dynamic shapes for quantization-aware training, but the export does not seem to satisfy some constraints. How can I fix this?

My PyTorch version is 2.5.

Can you try this:

# Suggested change: pass Dim.AUTO for dims 2 and 3 of "x" instead of named Dim objects.
# NOTE(review): confirm that torch.export.Dim.AUTO is available in the installed PyTorch version.
dynamic_shapes = (
  {"x":{ 
        2: torch.export.Dim.AUTO,
        3: torch.export.Dim.AUTO}}
)