Hi folks! I'm just getting started with PyTorch. I'm building a tool that wraps HuggingFace models in a custom `WrappedModel` so I can trace their execution with `torch.fx.symbolic_trace`. The goal is to analyze the traced graph and detect certain ops, such as `float32` usage.
To do this, I:

- Wrap the model in a subclass of `torch.nn.Module`.
- Run a `forward()` pass with dummy `input_ids`.
- Call `symbolic_trace(wrapped_model)`, or fall back to `torch.jit.trace()` (a minimal sketch of this flow is just below).
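Roughly how I drive those steps; the model name and dummy shape here are just what I happen to test with, and `WrappedModel` is the class shown further down:

```python
import torch
from torch.fx import symbolic_trace
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
wrapped = WrappedModel(model)  # defined below

# Fixed-size dummy input_ids
dummy_input = torch.randint(0, 1000, (1, 10))

# Tracing is where things go wrong
try:
    traced = symbolic_trace(wrapped)
except Exception:
    traced = torch.jit.trace(wrapped, dummy_input)  # fallback; same failure
```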
What's going wrong:

I consistently see:

```
Forward pass failed in WrappedModel — slice indices must be integers or None or have an __index__ method
```

And ultimately:

```
Rule run failed: 'method' object is not iterable
```
Likely problematic code:
```python
import logging

import torch


class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        try:
            batch_size = input_ids.size(0)
            seq_len = input_ids.size(1)
            # This line fails during symbolic tracing
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            logging.warning(f"TRACE ERROR inside wrapped forward: {e}")
            return torch.zeros(1, 1)
        if hasattr(output, "last_hidden_state"):
            return output.last_hidden_state
        elif hasattr(output, "logits"):
            return output.logits
        return output
```
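My current suspicion (possibly wrong): during `symbolic_trace`, the `input_ids` argument is a `torch.fx.Proxy`, so `input_ids.size(0)` returns a Proxy rather than a plain int, and somewhere downstream that Proxy ends up used as a slice index. A debug print I added inside `forward` (my own addition, not part of the tool) seems to confirm the Proxy part:

```python
def forward(self, input_ids):
    # Prints <class 'torch.Tensor'> in an eager forward pass,
    # but a torch.fx Proxy type while symbolic_trace is running
    print(type(input_ids))
    ...
```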
What I have already tried:

- Using `input_ids.size(0)` instead of `input_ids.shape[0]`.
- Making sure the dummy input has fixed dimensions: `torch.randint(0, 1000, (1, 10))`.
- Hardcoding the mask shape (e.g., `torch.ones((1, 10), dtype=torch.int64)`); see the snippet after this list.
- Falling back to `torch.jit.trace`: same error during the forward pass.
- Switching between `BertModel` and `BertForSequenceClassification`.
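For reference, the hardcoded-mask variant of `forward` looked roughly like this (reconstructed from memory; it still hit the same error):

```python
def forward(self, input_ids):
    # Hardcoded to match the (1, 10) dummy input, to avoid
    # calling input_ids.size() during tracing
    attention_mask = torch.ones((1, 10), dtype=torch.int64)
    output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    if hasattr(output, "last_hidden_state"):
        return output.last_hidden_state
    if hasattr(output, "logits"):
        return output.logits
    return output
```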
What (I think) I am asking for:

How do I make `torch.ones(...)` work inside a traced wrapper model during `symbolic_trace()`?
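One direction I've seen mentioned but haven't tried yet: transformers apparently ships its own FX tracer that knows about model input names. Would something like this be the right tool (untested sketch; I'm not sure it composes with my wrapper)?

```python
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")

# Untested: trace the bare HF model rather than my wrapper
traced = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask"])
```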
Thank you in advance for any guidance.