Dynamic slicing

I’m struggling a bit to make some fairly simple operations (like trimming the size of a mask to match varying-length sequences in a self-attention layer) play nicely with torch.compile.

One thing I don’t think I have a good mental model of is why dynamic slicing is not supported but masking is.

e.g.

import torch

# Dynamic slice: the length of the result depends on the value of i.
@torch.compile(backend="eager", fullgraph=True)
def f(i: torch.Tensor, x: torch.Tensor):
    return x[0:i]

# Boolean masking: same result, expressed as a mask over the first dimension.
@torch.compile(backend="eager", fullgraph=True)
def g(i: torch.Tensor, x: torch.Tensor):
    idx = torch.arange(x.shape[0])
    return x[idx < i]

f(torch.tensor(1), torch.tensor([1, 2, 3]))
g(torch.tensor(1), torch.tensor([1, 2, 3]))

f errors with “Unsupported: Dynamic slicing on data-dependent value is not supported” (PyTorch 2.6), but g works.

I think I just don’t have a good mental model for why it is easier for torch.compile to support masking than dynamic slicing.

You’re slicing the tensor x based on the value of i, which is data-dependent, so the compiler cannot statically determine at compile time how much of the input tensor gets sliced off. Masking doesn’t directly change the shape of the tensor; it operates on a boolean mask (idx < i).
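For comparison, here is a minimal sketch (f_int is my name, not from the thread): if the length comes in as a plain Python int rather than a tensor, the slice is no longer data-dependent from Dynamo’s point of view, so it compiles; Dynamo guards on the int (or treats it as a dynamic size).

# Sketch: pass the length as a Python int instead of a tensor.
# Dynamo can guard on (or symbolically track) the int, so the slice compiles.
import torch

@torch.compile(backend="eager", fullgraph=True)
def f_int(i: int, x: torch.Tensor):
    return x[0:i]

f_int(1, torch.tensor([1, 2, 3]))     # works
f_int(2, torch.tensor([1, 2, 3, 4]))  # works too (may recompile or go dynamic)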

Incidentally, I was not able to run either function without this:
torch._dynamo.config.capture_dynamic_output_shape_ops = True

Oh yeah, I forgot I had that set.

But the result, whether you use masking or slicing, is a dynamic shape … the effect of these two ops is identical.

Yes, they are equivalent in result. My last reply was slightly inaccurate. Masking like x[idx < i] does produce a dynamic shape too, since the number of True values in the mask depends on i.

The difference is that masking stays entirely in tensor ops, so the compiler can trace it. In contrast, slicing with x[0:i] goes through Python-level indexing with a tensor, which Dynamo doesn’t support and treats as a graph break; under fullgraph=True any graph break is an error, so the whole compile fails. If you drop fullgraph=True, Dynamo will fall back to eager mode for that part and continue, though you lose out on potential optimizations.
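As a quick check (a sketch, with f_fallback as a hypothetical name): compiling the same slicing function without fullgraph=True runs, with Dynamo graph-breaking at the slice and executing it eagerly.

# Sketch: same slicing function, compiled without fullgraph=True.
# Dynamo graph-breaks at the data-dependent slice and runs it eagerly
# instead of raising Unsupported.
import torch

@torch.compile(backend="eager")
def f_fallback(i: torch.Tensor, x: torch.Tensor):
    return x[0:i]

f_fallback(torch.tensor(1), torch.tensor([1, 2, 3]))  # runs, with a graph break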

I follow the explanation; what I’m missing is the mental model of why torch.compile doesn’t support slicing with a tensor (e.g. doing a runtime check on the tensor value and then slicing). Like, what makes that harder than supporting masking (which it does)?

This doc seems relevant but outdated:

Often, the graph breaks are just missing functionality in Dynamo, please submit bugs for these. Sometimes, you can work around these by simplifying your Python code (we still encourage to submit a bug though, we would like to support all of these patterns!) However, there are some particular situations which are likely to result in lots of graph breaks and need some different treatment.

By default, if you have any data-dependent computation, e.g., boolean masking, item() call, nonzero(), etc, this will trigger a graph break. If you are feeling brave, you can get past these problems by setting torch._dynamo.config.capture_scalar_outputs = True and torch._dynamo.config.capture_dynamic_output_shape_ops = True. You will want to read Dealing with GuardOnDataDependentSymNode errors next.

Based on this, the setting you applied should let torch.compile handle both, since they’re both data-dependent computations. Maybe file an issue.
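If you need the slice itself inside the compiled region, one pattern worth trying is to read the length out as an unbacked symbolic int with .item() and then slice with narrow. This is only a sketch under the two config flags above; f_sliced is a hypothetical name, and the exact torch._check hints required can vary across PyTorch versions.

import torch

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(backend="eager", fullgraph=True)
def f_sliced(i: torch.Tensor, x: torch.Tensor):
    n = i.item()                   # unbacked symbolic int
    torch._check_is_size(n)        # hint to the compiler: n is a valid size (n >= 0)
    torch._check(n <= x.size(0))   # hint: the slice stays in bounds
    return x.narrow(0, 0, n)       # equivalent to x[0:n] along dim 0

f_sliced(torch.tensor(1), torch.tensor([1, 2, 3]))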