Dynamic slicing with torch.export

I want to do this:

import torch
from torch.export import export


class DynamicShapeSlicing(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.probs = torch.rand((3,))

    def forward(self, x, lengths):
        idx = self.probs.argmax()
        return x[idx, :, : lengths[idx]]


example_args = (torch.randn(3, 10, 8), torch.randint(0, 8, (3,)))
model = DynamicShapeSlicing()


ex = export(model, example_args, strict=False)
ex.module()(*example_args)

This is the error I get:

GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none)

Potential framework code culprit (scroll up for full backtrace):
File "/home/will/.pyenv/versions/3.10.6/envs/notebook/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 520, in torch_function
return func(*args, **kwargs)

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see Dealing with GuardOnDataDependentSymNode errors - Google Docs

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
File "/tmp/ipykernel_1216314/1925435421.py", line 16, in forward
return x[idx, :, : lengths[idx]]

What is the proper way to do this that will allow me to export my model?

This is what I ended up having to do. I'm not convinced it's the best method, and it's obviously not very clean.

class DynamicShapeSlicing(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.probs = torch.rand((3,))

    def forward(self, x, lengths):
        # .item() turns the 0-d tensor into an unbacked SymInt instead of
        # guarding on its concrete value
        idx = self.probs.argmax().item()
        # Tell the export machinery the index is a valid, in-bounds size
        torch._check(x.size(0) > idx)
        torch._check(0 <= idx)
        torch._check(lengths.size(0) > idx)
        torch._check_is_size(idx)
        # Same trick for the data-dependent slice length
        length = lengths.narrow(0, idx, 1).item()
        torch._check(length < x.size(2))
        torch._check_is_size(length)
        # narrow() expresses the slice in terms of the unbacked symbols
        return x.narrow(0, idx, 1).narrow(2, 0, length)
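One thing worth noting about this workaround: `x.narrow(0, idx, 1)` keeps the leading dimension, so the result has shape `(1, 10, length)` rather than the `(10, length)` the original `x[idx, :, : lengths[idx]]` produced. Here's a quick eager-mode sanity check (a sketch using the same example shapes as above, nothing export-specific) showing the two agree up to that extra dimension:

```python
import torch

# Same example shapes as in the post
x = torch.randn(3, 10, 8)
lengths = torch.randint(0, 8, (3,))
probs = torch.rand((3,))

idx = int(probs.argmax())
length = int(lengths[idx])

direct = x[idx, :, :length]                           # shape (10, length)
via_narrow = x.narrow(0, idx, 1).narrow(2, 0, length)  # shape (1, 10, length)

# The narrow-based version matches once the kept leading dim is squeezed
assert torch.equal(via_narrow.squeeze(0), direct)
```

So depending on what consumes the output, a `squeeze(0)` after the narrows may be needed to preserve the original semantics.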

Can you file a GitHub issue? Relevant folks can discuss whether there is a cleaner way to do this (or, if not, whether there's anything to do to make this code more export-friendly by default).