Using torch.fx.symbolic_trace with view(*shape)


I’m trying to use torch.fx.symbolic_trace to trace the bert.BertSelfAttention module. However, I’m running into the following error when torch.Tensor.view is called as x = x.view(*new_shape):

torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when used in a for loop or as a *args or **kwargs function argument.

Could this be caused by the *new_shape argument to view? I’ve been trying to write a custom Tracer class with a new iter method, but I haven’t been able to resolve the issue. Is there a workaround that doesn’t require modifying the actual self-attention code?
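A minimal way to check the *new_shape hypothesis is a toy module. It is hypothetical and assumed to mirror the new_shape = x.size()[:-1] + (...) pattern used in the BERT attention code. Under tracing, x.size() returns a Proxy, so new_shape is a Proxy too, and unpacking it with * makes Python iterate it, which the default fx tracer rejects with this TraceError. Passing the shape as a single argument, x.view(new_shape), traces fine, but that of course means editing the module. The class and helper names below are made up for illustration:

```python
import torch
import torch.fx
from torch.fx.proxy import TraceError


class SplitHeadsStarArgs(torch.nn.Module):
    """Hypothetical toy module, assumed to mirror the BERT reshape pattern."""

    def __init__(self, num_heads: int = 2):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, x):
        # Under tracing, x.size() is a Proxy, so new_shape is a Proxy as well.
        new_shape = x.size()[:-1] + (self.num_heads, -1)
        # *new_shape forces Python to iterate the Proxy, which fx refuses to do.
        return x.view(*new_shape)


class SplitHeadsSingleArg(SplitHeadsStarArgs):
    def forward(self, x):
        new_shape = x.size()[:-1] + (self.num_heads, -1)
        # The shape Proxy is recorded as one argument of the view node,
        # so nothing has to be iterated at trace time.
        return x.view(new_shape)


def try_trace(module: torch.nn.Module):
    """Return a GraphModule, or the TraceError message if tracing fails."""
    try:
        return torch.fx.symbolic_trace(module)
    except TraceError as err:
        return str(err)


if __name__ == "__main__":
    print(try_trace(SplitHeadsStarArgs()))  # the "cannot be iterated" message
    gm = try_trace(SplitHeadsSingleArg())
    print(gm.code)
    print(gm(torch.randn(3, 4, 6)).shape)  # torch.Size([3, 4, 2, 3])
```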

Full code:

from torch.fx import symbolic_trace
from transformers.models.bert import modeling_bert as bert
from transformers.models.bert import BertConfig

bert_config = BertConfig(hidden_size=10, num_attention_heads=5)
attention_module = bert.BertSelfAttention(bert_config)
graph = symbolic_trace(attention_module).graph

Models like these can be fairly involved, especially in a library with the origins and target audience of transformers.

As chance would have it, I released an initial commit of toroidal transformers, whose attention module works well with symbolic tracing:

m = toroidal.models.Attention(786, 3)
gr = torch.fx.symbolic_trace(m)

or, in a BERT model (the multilingual one from the test directory):

gr = torch.fx.symbolic_trace(model_bert.blocks[0].attn)



  (qkv): Linear(in_features=768, out_features=2304, bias=True)
  (dropout_attn): Dropout(p=0.1, inplace=False)
  (proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout_out): Dropout(p=0.1, inplace=False)

def forward(self, x):
    getattr_1 = x.shape
    getitem = getattr_1[0]
    getitem_1 = getattr_1[1]
    getitem_2 = getattr_1[2];  getattr_1 = None
    qkv = self.qkv(x);  x = None
    view = qkv.view(getitem, getitem_1, 3, 12, -1);  qkv = None
    unbind = view.unbind(dim = 2);  view = None
    getitem_3 = unbind[0]
    getitem_4 = unbind[1]
    getitem_5 = unbind[2];  unbind = None
    einsum = torch.functional.einsum('bthc,bshc->bhts', getitem_3, getitem_4);  getitem_3 = getitem_4 = None
    mul = einsum * 0.125;  einsum = None
    softmax = mul.softmax(-1);  mul = None
    dropout_attn = self.dropout_attn(softmax);  softmax = None
    einsum_1 = torch.functional.einsum('bhts,bshc->bthc', dropout_attn, getitem_5);  dropout_attn = getitem_5 = None
    reshape = einsum_1.reshape(getitem, getitem_1, getitem_2);  einsum_1 = getitem = getitem_1 = getitem_2 = None
    proj = self.proj(reshape);  reshape = None
    dropout_out = self.dropout_out(proj);  proj = None
    return dropout_out

The disadvantage, I guess, comes when you want to manipulate the matmuls: I used einsum instead. 🙂
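For comparison, here is a hypothetical module reconstructed from the traced forward above; the class name, argument names and defaults are assumptions, and this is not the toroidal source. It traces because x.shape and the result of unbind are only unpacked by plain tuple assignment, which fx resolves by counting the assignment targets, and view receives separate sizes rather than a *-unpacked Proxy. Listing the call_function targets of the traced graph shows torch.einsum rather than torch.matmul, which is the trade-off mentioned:

```python
import math

import torch
import torch.fx
from torch import nn


class EinsumAttention(nn.Module):
    """Hypothetical reconstruction of the module behind the traced forward."""

    def __init__(self, dim: int = 768, heads: int = 12, p: float = 0.1):
        super().__init__()
        self.heads = heads
        self.scale = 1.0 / math.sqrt(dim // heads)  # 0.125 for 768 / 12 = 64
        self.qkv = nn.Linear(dim, 3 * dim)
        self.dropout_attn = nn.Dropout(p)
        self.proj = nn.Linear(dim, dim)
        self.dropout_out = nn.Dropout(p)

    def forward(self, x):
        # Plain tuple unpacking: fx knows there are 3 targets and emits 3 getitems.
        b, t, c = x.shape
        # (b, t, 3, heads, head_dim) split into q, k, v of shape (b, t, heads, head_dim)
        q, k, v = self.qkv(x).view(b, t, 3, self.heads, -1).unbind(dim=2)
        scores = torch.einsum("bthc,bshc->bhts", q, k) * self.scale
        attn = self.dropout_attn(scores.softmax(-1))
        out = torch.einsum("bhts,bshc->bthc", attn, v)
        return self.dropout_out(self.proj(out.reshape(b, t, c)))


if __name__ == "__main__":
    gm = torch.fx.symbolic_trace(EinsumAttention().eval())
    print(gm.code)
    print([n.target for n in gm.graph.nodes if n.op == "call_function"])
```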

If GPT works for you, you could also check out minGPT; I’d imagine A. Karpathy’s code works more easily with the tracer.

Best regards



Thanks for the quick response! I’ll look into this!