I’m trying to use `torch.fx.symbolic_trace` to trace the `bert.BertSelfAttention` module from `transformers`. However, I’m running into the following error when `torch.Tensor.view` is called as `x = x.view(*new_shape)`:
```
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.
```
Is this possibly caused by the `*new_shape` argument to `view`? I’ve been trying to write a custom `Tracer` class that overrides the `iter` method, but I haven’t been able to resolve the issue. Is there a workaround that doesn’t require modifying the actual self-attention code?
```python
from torch.fx import symbolic_trace
from transformers.models.bert import modeling_bert as bert
from transformers.models.bert import BertConfig

bert_config = BertConfig(hidden_size=10, num_attention_heads=5)
attention_module = bert.BertSelfAttention(bert_config)
graph = symbolic_trace(attention_module).graph
```
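One thing I’ve considered (not verified as a fix for my case): `transformers` ships its own FX tracer, `transformers.utils.fx.symbolic_trace`, which traces with concrete dummy inputs so shape tuples are real ints rather than Proxies. As far as I can tell it targets full `PreTrainedModel` instances, so this traces the whole `BertModel` rather than the bare `BertSelfAttention` submodule:

```python
# Sketch only: relies on transformers' own FX tracing utilities rather than
# plain torch.fx.symbolic_trace.
from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace

config = BertConfig(hidden_size=12, num_attention_heads=4,
                    num_hidden_layers=1, intermediate_size=24)
model = BertModel(config)  # randomly initialised; nothing is downloaded
traced = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask"])
print(len(traced.graph.nodes) > 0)
```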