torch.fx symbolic_trace fails with TypeError on torch.ones: “slice indices must be integers or None”

Hi folks! I'm new to PyTorch. I'm building a tool that wraps HuggingFace models in a custom WrappedModel so I can trace their execution with torch.fx.symbolic_trace. The goal is to analyze the traced graph and detect certain patterns, such as float32 usage.
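
For the analysis step, here is a minimal sketch of flagging float32 values in a traced graph. It assumes the module already traces cleanly. It relies on ShapeProp from torch.fx.passes.shape_prop, which runs the graph once on example inputs and stores each tensor-producing node's dtype in node.meta["tensor_meta"]. TinyNet and find_float32_nodes are hypothetical names used only for illustration.

```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class TinyNet(torch.nn.Module):
    """Hypothetical toy model: a float32 linear layer followed by a cast to float16."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x)).half()


def find_float32_nodes(gm: torch.fx.GraphModule, *example_inputs: torch.Tensor) -> list[str]:
    """Return the names of graph nodes whose output tensor is float32."""
    # ShapeProp executes the graph once and stores shape/dtype metadata on each node.
    ShapeProp(gm).propagate(*example_inputs)
    flagged = []
    for node in gm.graph.nodes:
        meta = node.meta.get("tensor_meta")
        # Nodes producing non-tensors (meta is None) or tuples have no single dtype.
        if getattr(meta, "dtype", None) == torch.float32:
            flagged.append(node.name)
    return flagged


gm = torch.fx.symbolic_trace(TinyNet())
print(find_float32_nodes(gm, torch.randn(2, 4)))
```

On this toy model, the node produced by .half() is float16, so the check does not flag it. A real rule would probably also need to inspect parameter dtypes (for example via named_parameters()), and this sketch does not do that.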

To trace a model, I:

  • Wrap the model in a subclass of torch.nn.Module.
  • Run a forward() pass with dummy input_ids.
  • Call symbolic_trace(wrapped_model), falling back to torch.jit.trace() if that fails.
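
The three steps above can be sketched as a small helper. This sketch assumes a toy module in place of the HuggingFace model. ToyEncoder, Wrapper and trace_with_fallback are hypothetical names. The fallback uses torch.jit.trace(module, example_inputs), which records the operations actually executed for one concrete input instead of tracing symbolically.

```python
import logging
from typing import Union

import torch
import torch.fx


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in for a HuggingFace model."""

    def __init__(self, vocab_size: int = 1000, hidden: int = 8) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)


class Wrapper(torch.nn.Module):
    """Step 1: wrap the model in an nn.Module subclass."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids)


def trace_with_fallback(
    module: torch.nn.Module, example_input: torch.Tensor
) -> Union[torch.fx.GraphModule, torch.jit.ScriptModule]:
    # Step 2: eager forward pass first, so plain runtime bugs surface before tracing.
    with torch.no_grad():
        module(example_input)
    # Step 3: try FX symbolic tracing, then fall back to example-based JIT tracing.
    try:
        return torch.fx.symbolic_trace(module)
    except Exception as exc:
        logging.warning("symbolic_trace failed (%s); falling back to torch.jit.trace", exc)
        return torch.jit.trace(module, (example_input,))


dummy_ids = torch.randint(0, 1000, (1, 10))
traced = trace_with_fallback(Wrapper(ToyEncoder()), dummy_ids)
print(type(traced).__name__)
```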

What’s going wrong:

I consistently see:

Forward pass failed in WrappedModel — slice indices must be integers or None or have an __index__ method
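
For reference, Python itself raises this exact TypeError when a list or tuple is sliced with a bound that is not an integer and has no __index__ method. The logs above do not show where it is raised in this case. One assumption worth checking is whether some slicing expression inside the model receives a traced value (an FX Proxy derived from input_ids.size(...)) instead of an int. Printing the full traceback, instead of only str(e), shows the exact file and line. Here is a minimal sketch in which a string stands in for the non-integer bound. slice_prefix and describe_error are hypothetical helpers.

```python
import traceback


def slice_prefix(values: list, stop):
    """Return values[:stop]; for list slicing, stop must be an int or None."""
    return values[:stop]


def describe_error(fn, *args) -> str:
    """Run fn and return the full formatted traceback if it raises, else an empty string."""
    try:
        fn(*args)
    except Exception:
        # format_exc() includes every frame's file and line, not just str(exc).
        return traceback.format_exc()
    return ""


report = describe_error(slice_prefix, [1, 2, 3], "2")
print(report)
```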

And ultimately:

Rule run failed: ‘method’ object is not iterable
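
The rule code is not shown, so the following covers only one common way this second message arises: iterating over a method without calling it. This minimal sketch assumes the rule walks the modules of a model. count_modules_buggy and count_modules are hypothetical names.

```python
import torch


def count_modules_buggy(module: torch.nn.Module) -> int:
    # Bug: named_modules is a method; iterating the bound method object itself
    # raises TypeError: 'method' object is not iterable.
    return sum(1 for _ in module.named_modules)


def count_modules(module: torch.nn.Module) -> int:
    # Fixed: call the method to get the generator of (name, submodule) pairs.
    return sum(1 for _ in module.named_modules())


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print(count_modules(model))
```

By contrast, gm.graph.nodes on a traced GraphModule is a property, so it is iterated without parentheses.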

Likely problematic code:

```python
import logging

import torch


class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        try:
            batch_size = input_ids.size(0)
            seq_len = input_ids.size(1)

            # This line fails during symbolic tracing
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)

            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            # Any exception raised above (including tracing errors) is logged
            # and replaced by a placeholder output.
            logging.warning(f"TRACE ERROR inside wrapped forward: {e}")
            return torch.zeros(1, 1)

        # HuggingFace outputs expose different fields depending on the model head.
        if hasattr(output, "last_hidden_state"):
            return output.last_hidden_state
        elif hasattr(output, "logits"):
            return output.logits
        return output
```
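
Here is a minimal sketch of a mask construction that symbolic tracing can record, using a toy model in place of BertModel. The torch.fx documentation lists tensor constructors such as torch.ones among its tracing limitations when their sizes depend on the input, and it suggests ones_like/zeros_like as a substitute. torch.ones_like receives the traced input_ids directly, so FX records it as a graph node. The sketch also drops the try/except, because inside forward that block turns any tracing error into a torch.zeros(1, 1) return value and hides the original traceback. ToyModel and MaskedWrapper are hypothetical names. This sketch does not establish that the change alone makes a real HuggingFace model traceable, because the model's own forward may contain other untraceable code. transformers also ships its own tracer, transformers.utils.fx.symbolic_trace, which is aimed at its models.

```python
import torch
import torch.fx


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for a HuggingFace encoder that takes an attention mask."""

    def __init__(self, vocab_size: int = 1000, hidden: int = 8) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden = self.embed(input_ids)
        # Zero out masked positions; the (batch, seq) mask broadcasts over the hidden dim.
        return hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)


class MaskedWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # ones_like takes the (traced) input directly, so no Python ints are needed
        # for the shape; the mask also inherits the device of input_ids.
        attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
        return self.model(input_ids=input_ids, attention_mask=attention_mask)


wrapper = MaskedWrapper(ToyModel())
traced = torch.fx.symbolic_trace(wrapper)
print(traced.code)
```

Because the mask shape is taken from the input at run time, the traced module should also accept inputs of a different size than the dummy input.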

What I have already tried:

  • Using input_ids.size(0) instead of input_ids.shape[0]
  • Making sure the dummy input has fixed dimensions: torch.randint(0, 1000, (1, 10))
  • Hardcoding the mask shape (e.g., torch.ones((1, 10), dtype=torch.int64))
  • Falling back to torch.jit.trace: the same error occurs during the forward pass
  • Switching between BertModel and BertForSequenceClassification

What (I think) I am asking for:

How do I make torch.ones(...) work inside a traced wrapper model during symbolic_trace()?

Thank you in advance for any guidance.