I’m trying to use `torch.fx.symbolic_trace` to trace the `BertSelfAttention` module from Hugging Face `transformers`. However, I run into the following error when `torch.Tensor.view` is called as `x = x.view(*new_shape)`:
```
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.
```
Could this be caused by the `*new_shape` argument to `view`? I’ve tried writing a custom `Tracer` subclass with an overridden `iter` method, but I haven’t been able to resolve the issue. Is there a workaround that doesn’t require modifying the actual self-attention code?
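The error message points at iteration, and star-unpacking is exactly that: `f(*seq)` makes Python call `iter(seq)` to build the positional-argument tuple, while `f(seq)` passes the object through untouched. During tracing, `x.size()` returns a `Proxy`, so `new_shape` is a `Proxy` too. The sketch below uses hypothetical stand-ins (`NonIterableProxy`, `ProxyIterationError`, and a toy `view`), not the real torch classes, to show the mechanism in plain Python.

```python
class ProxyIterationError(Exception):
    """Hypothetical stand-in for torch.fx.proxy.TraceError."""


class NonIterableProxy:
    """Hypothetical stand-in for torch.fx.Proxy: counts and refuses iteration."""

    def __init__(self):
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        raise ProxyIterationError("Proxy object cannot be iterated")


def view(*shape):
    """Hypothetical stand-in for Tensor.view: just returns what it received."""
    return shape


p = NonIterableProxy()

# Passing the object as one argument: Python never calls __iter__.
received = view(p)
calls_after_single_arg = p.iter_calls

# Star-unpacking: Python must iterate p to build the argument tuple, so __iter__ runs.
try:
    view(*p)
    star_call_failed = False
except ProxyIterationError:
    star_call_failed = True

print(received[0] is p, calls_after_single_arg, star_call_failed, p.iter_calls)
```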
Full code:
```python
from torch.fx import symbolic_trace
from transformers.models.bert import modeling_bert as bert
from transformers.models.bert import BertConfig

bert_config = BertConfig(hidden_size=10, num_attention_heads=5)
attention_module = bert.BertSelfAttention(bert_config)
graph = symbolic_trace(attention_module).graph  # raises the TraceError above
```
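One possible workaround, sketched here on a hypothetical toy module rather than on `transformers` itself: patch the offending method on the class at runtime, before tracing, with a version that passes the shape to `view` as a single argument (which `Tensor.view` also accepts). The names `ToyHeadSplit`, `split_heads` and `split_heads_traceable` are illustrative only; they mirror the reshape-then-permute pattern described above, and applying the same idea to `BertSelfAttention` would depend on what the installed `transformers` version's source actually looks like.

```python
import torch
import torch.fx
from torch.fx.proxy import TraceError


class ToyHeadSplit(torch.nn.Module):
    """Hypothetical module reproducing the `view(*new_shape)` reshape pattern."""

    def __init__(self, num_heads: int, head_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # While tracing, x.size() is a Proxy, so new_shape is a Proxy as well.
        new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        return x.view(*new_shape).permute(0, 2, 1, 3)  # star-unpacks the Proxy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.split_heads(x)


def split_heads_traceable(self, x: torch.Tensor) -> torch.Tensor:
    """Same computation, but the shape goes to view() as one argument."""
    new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
    return x.view(new_shape).permute(0, 2, 1, 3)


module = ToyHeadSplit(num_heads=5, head_size=2)

# The unpatched module is expected to fail just like in the question.
try:
    torch.fx.symbolic_trace(module)
    star_version_traced = True
except TraceError:
    star_version_traced = False

# Patch the class at runtime (the module's source is untouched), trace, then restore.
original = ToyHeadSplit.split_heads
ToyHeadSplit.split_heads = split_heads_traceable
try:
    traced = torch.fx.symbolic_trace(module)
finally:
    ToyHeadSplit.split_heads = original

x = torch.randn(3, 4, 10)
print(star_version_traced)
print(torch.equal(traced(x), module(x)))
print(traced.code)
```

Restoring the original method in `finally` keeps eager-mode behaviour unchanged for any other code using the class. The traced `GraphModule` never calls `split_heads` itself, because symbolic tracing inlines the method's operations into the graph.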