How to hint torch.export that a tensor is a constant?

I’m trying to export a model that builds a tensor out of the input tensors’ spatial shapes, and I can’t find a way to hint that such a tensor will be constant. Here’s a minimal reproducible example:

import typing

import torch
import torch.nn as nn
from torch.export import export

class NetMRE(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for xi in x:
            spatial_shapes.append(xi.shape[2:])

        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long)  # downstream code expects a tensor
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list

example_kwargs = {
    "x": [
        torch.rand(1, 3, 64, 64),
        torch.rand(1, 3, 32, 32),
    ],
}
exported_program: torch.export.ExportedProgram = export(
    NetMRE(), (), kwargs=example_kwargs, strict=True,
)

This throws the following error:

GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
  File ".venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_9748/14400332.py", line 10, in forward
    lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
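
If I understand the error correctly, unpacking rows of spatial_shapes reads tensor data, which under export becomes a data-dependent .item() call producing the unbacked SymInt u0, and linspace then needs a concrete integer for its steps argument. A stripped-down sketch of what I believe is the same failure mode (ItemRepro is just an illustrative name):

class ItemRepro(nn.Module):
    def forward(self, t: torch.Tensor):
        n = t[0].item()  # becomes an unbacked SymInt (u0) under export
        # steps must be a concrete int, so the tracer guards on u0 and fails
        return torch.linspace(0.5, n - 0.5, n, dtype=torch.float32)

# torch.export.export(ItemRepro(), (torch.tensor([5], dtype=torch.long),), strict=True)
# raises a similar GuardOnDataDependentSymNode error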

I’ve found that a simple fix is to not recast the list of shapes into a tensor, but that would require changing downstream code that assumes it’s a tensor. If there’s a way to hint the compiler to specialize the tensor’s values, that would be amazing.
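
For concreteness, the no-tensor version exports cleanly, since the shapes stay plain Python ints taken from (static) input metadata. A sketch (NetMREFixed is just an illustrative name):

class NetMREFixed(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        # Plain Python ints from input metadata; export specializes these
        # to constants instead of treating them as tensor data.
        spatial_shapes = [tuple(xi.shape[2:]) for xi in x]
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32)
            reference_points_list.append(lin)
        return reference_points_list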

Thanks in advance.

This is most likely a very-hacky-and-probably-wrong way, but after debugging for some time I’ve found that it’s possible to avoid the conversion to FakeTensor, and that sort of fixes the issue. Or at least it generates the desired ExportedProgram. However, because of the use of non-trivial Python (context managers and torch._C._set_dispatch_mode), it only works with strict=False.

from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

class NetC2v2(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for xi in x:
            spatial_shapes.append(tuple(xi.shape[2:]))
        
        # Build the tensor outside of fake-tensor and proxy tracing so its
        # values stay real, i.e. it ends up as a constant in the graph.
        with unset_fake_temporarily():
            with disable_proxy_modes_tracing():
                spatial_shapes_tensor = torch.tensor(spatial_shapes, dtype=torch.long)
        reference_points = self.get_reference_points(spatial_shapes_tensor)
        return reference_points

    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor,
    ): 
        reference_points_list = []
        for i in range(len(spatial_shapes)):
            with disable_proxy_modes_tracing():
                with unset_fake_temporarily():
                    # Read the real (non-fake) value so it becomes a plain int
                    h = int(spatial_shapes[i][0])
            lin = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32)
            reference_points_list.append(lin)
        return reference_points_list
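
For completeness, the export call then has to use strict=False. Something like the following produced the program below (I’m reading the four input shapes off the graph; they differ from the two-level example at the top):

example_kwargs = {
    "x": [
        torch.rand(1, 3, 150, 100),
        torch.rand(1, 3, 75, 50),
        torch.rand(1, 3, 37, 25),
        torch.rand(1, 3, 19, 13),
    ],
}
exported_program = torch.export.export(
    NetC2v2(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)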

Resulting ExportedProgram:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0: "f32[1, 3, 150, 100]", x_1: "f32[1, 3, 75, 50]", x_2: "f32[1, 3, 37, 25]", x_3: "f32[1, 3, 19, 13]"):
            linspace: "f32[150]" = torch.ops.aten.linspace.default(0.5, 149.5, 150, device = device(type='cpu'), pin_memory = False)
            linspace_1: "f32[75]" = torch.ops.aten.linspace.default(0.5, 74.5, 75, device = device(type='cpu'), pin_memory = False)
            linspace_2: "f32[37]" = torch.ops.aten.linspace.default(0.5, 36.5, 37, device = device(type='cpu'), pin_memory = False)
            linspace_3: "f32[19]" = torch.ops.aten.linspace.default(0.5, 18.5, 19, device = device(type='cpu'), pin_memory = False)
            return (linspace, linspace_1, linspace_2, linspace_3)

I have no idea what I’m doing, so any help is appreciated haha