ONNX export fails for custom interpolate function with dynamic scale factor in torch.export.export (PyTorch 2.9.0)

Description

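I’m hitting a failure when exporting a model to ONNX with torch.onnx.export. The model uses a custom torch.autograd.Function that performs bicubic interpolation, and the scale factor is passed in as a tensor input. The export fails while torch.onnx.export obtains the model graph with torch.export.export. upsample_bicubic2d raises a TypeError because the scale factors reach it as symbolic integers rather than floats.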

Environment

  • PyTorch Version: 2.9.0+cu128
  • ONNX opset: 18
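PyTorch's bug-report template also asks for the output of its environment script (python -m torch.utils.collect_env). For the export-specific packages, a small sketch like the one below can fill in the missing versions. It assumes that torch, onnx and onnxscript are the relevant packages for this issue. package_versions is a hypothetical helper, not a PyTorch API.

```python
# Hypothetical helper for the Environment section. The package list is an
# assumption about which versions matter for the dynamo-based ONNX exporter.
from importlib import metadata


def package_versions(names=("torch", "onnx", "onnxscript")):
    """Return {package: version string}, or None for packages not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
```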

Steps to Reproduce

  1. Define a custom torch.autograd.Function for interpolation with a dynamic scale factor.
  2. Use this function in a model’s forward method.
  3. Attempt to export the model to ONNX using torch.onnx.export.

Minimal Reproduction Code:

import torch
from torch import nn
from torch.nn.functional import interpolate
import torch.onnx

class NewInterpolate(torch.autograd.Function):
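    # ONNX mapping for the TorchScript-based exporter: emits a Resize node
    # whose scales input comes from the scale_factor model input.
    @staticmethod
    def symbolic(g, input, scales):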
        return g.op(
            "Resize",
            input,
            g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)),
            scales,
            coordinate_transformation_mode_s="pytorch_half_pixel",
            cubic_coeff_a_f=-0.75,
            mode_s="cubic",
            nearest_mode_s="floor"
        )

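    @staticmethod
    def forward(ctx, input, scales):
        # scales.tolist() reads tensor values; under torch.export these become
        # symbolic (SymInt in the log below) instead of Python floats.
        return interpolate(input, scale_factor=scales.tolist()[-2:], mode="bicubic", align_corners=False)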

class StrangeSuperResolutionNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x, upscale_factor):
        x = NewInterpolate.apply(x, upscale_factor)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out

model = StrangeSuperResolutionNet()
model.eval()
factor = torch.tensor([1, 1, 3, 3], dtype=torch.float32)  # Resize scales for (N, C, H, W)
x = torch.randn(1, 3, 256, 256)

torch.onnx.export(
    model, 
    (x, factor), 
    "srcnn2.onnx", 
    opset_version=18, 
    input_names=['input', 'scale_factor'],
    output_names=['output']
)

Error Log

[torch.onnx] Obtain model graph for `StrangeSuperResolutionNet` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `StrangeSuperResolutionNet` with `torch.export.export(..., strict=True)`... ❌
Traceback (most recent call last):
  ...
TypeError: upsample_bicubic2d() received an invalid combination of arguments - got (FakeTensor, NoneType, bool, list), but expected one of:
 * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (FakeTensor, NoneType, bool, list of [SymInt, SymInt])
 * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h = None, float scales_w = None, *, Tensor out = None)

Expected Behavior

The model should export to ONNX. The scale factor should stay a runtime model input (scale_factor) that feeds the scales input of the ONNX Resize node, as the symbolic() method defines.

Additional Context

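  • The failure happens while torch.export.export traces forward(). The traceback points at the interpolate call, so symbolic() is not used on this path. interpolate receives scale_factor as a list built from scales.tolist()[-2:]. Under fake-tensor tracing, those entries are symbolic values (list of [SymInt, SymInt], per the log) rather than Python floats, so no upsample_bicubic2d overload matches.
  • The symbolic() method was written for the older TorchScript-based exporter. The same code may have worked in earlier PyTorch versions, where that exporter was the default.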

Question

What is the recommended approach in PyTorch 2.9 for exporting models with dynamic resize/scale factors to ONNX, particularly when using the new torch.export.export infrastructure?