I’m trying to export a model that uses a tensor of tensor shapes, and I can’t find a way to hint that such a tensor will be constant. Here’s a minimal reproducible example:
```python
import typing

import torch
import torch.nn as nn
from torch.export import export


class NetMRE(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for xi in x:
            spatial_shapes.append(xi.shape[2:])
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long)
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32)  # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list


example_kwargs = {
    "x": [
        torch.rand(1, 3, 64, 64),
        torch.rand(1, 3, 32, 32),
    ],
}
exported_program: torch.export.ExportedProgram = export(
    NetMRE(), (), kwargs=example_kwargs, strict=True,
)
```
Which throws the following error:
```
GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)
Potential framework code culprit (scroll up for full backtrace):
  File ".venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_9748/14400332.py", line 10, in forward
    lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
```
I’ve found that a simple fix is to not recast the list of shapes to a tensor of shapes, but that would require changing code that assumes it’s a tensor. If there’s a way to hint to the compiler that it should specialize on the tensor’s values, that would be amazing.
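A minimal sketch of that workaround, assuming the code that needs a tensor only consumes `spatial_shapes` and never iterates over it to produce size arguments. The idea is to keep the shapes as plain Python ints for the loop and build the tensor separately. `NetShapesAsInts` and `shape_list` are hypothetical names for illustration, not an existing API.

```python
import typing

import torch
import torch.nn as nn


class NetShapesAsInts(nn.Module):  # hypothetical variant of NetMRE
    def forward(self, x: typing.List[torch.Tensor]):
        # Keep the spatial sizes as Python ints; with static example shapes
        # these are plain constants at trace time, not tensor values.
        shape_list: typing.List[typing.Tuple[int, int]] = [
            (xi.shape[2], xi.shape[3]) for xi in x
        ]
        # Tensor form, still available for code that expects a tensor,
        # but it is never iterated to drive size arguments.
        spatial_shapes = torch.as_tensor(shape_list, dtype=torch.long)

        reference_points_list = []
        for H, W in shape_list:
            # H is a Python int here, so linspace gets a concrete step count.
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32)
            reference_points_list.append(lin)
        return reference_points_list, spatial_shapes
```

With static example shapes, passing the same `example_kwargs` to `export(..., strict=True)` should then avoid the data-dependent `linspace` call, because `H` never round-trips through a tensor. Note this sidesteps the problem rather than marking a tensor's values as constant.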
Thanks in advance.