BERT torch graph tensor shape issues


First of all, I’m new to this forum so excuse me if I’m not addressing the right audience.

I’m working on getting BERT forward and backprop traces from PyTorch jit and I ran into a problem: I seem to have encountered a case where the tensor dimensions don’t line up with my expectations. I’ll attach the full graph dump, but it’s unwieldy, so let me summarize:

The forward graph has the following relevant lines:

  %x.8 : Float(1, 3, 10) = aten::add(%output.6, %bias, %28)
  %self_size.3 : int[] = aten::size(%x.8)
  %57 : int = prim::Constant[value=-1](), scope: BERT/SublayerConnection
  %114 : int[] = prim::ListConstruct(%57), scope: BERT/SublayerConnection/LayerNorm[norm]
  %mean : Float(1, 3, 1) = aten::mean(%x.8, %114, %44), scope: BERT/SublayerConnection/LayerNorm[norm]
  %307 : int[] = aten::size(%mean)

The backprop graph has three inputs that are important here, with the following connections:

   %76 : int[],           <== %114          (input[76]  connected to output[40])
   %self_size.2 : int[],  <== %self_size.3  (input[106] connected to output[70]) 
   %108 : int[],          <== %307          (input[108] connected to output[72])

And these values propagate through the backprop graph as follows:

  %130 : int = prim::Constant[value=0](), scope: BERT/BERTEmbedding[embedding]/TokenEmbedding[token]
  %300 : int = aten::select(%76, %130)
  %225 : Tensor, %226 : Tensor = prim::GradOf[name="aten::sub"](%224)
      %227 : Tensor = aten::_grad_sum_to_size(%224, %self_size.2)
      %228 : Tensor = aten::neg(%224)
      %229 : Tensor = aten::mul(%228, %134), scope: BERT/SublayerConnection/LayerNorm[norm]
      %230 : Tensor = aten::_grad_sum_to_size(%229, %108), scope: BERT/SublayerConnection/LayerNorm[norm]
      -> (%227, %230)
  %301 : Tensor = aten::unsqueeze(%226, %300)
  %128 : bool = prim::Constant[value=0](), scope: BERT/BERTEmbedding[embedding]/TokenEmbedding[token]
  %302 : Tensor = aten::expand(%301, %self_size.2, %128)

So, if I read this right:

  • %128 is just a constant ‘false’, saying that the expand is an explicit one.
  • The second parameter (%self_size.2) is the list [1,3,10]
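  • The first parameter (%301) is the result of unsqueeze(%226, %300). %300 is select(%76, 0), and %76 is the list [-1] built as %114 in the forward graph, so this amounts to unsqueeze(%226, -1). %226 is the second output of the GradOf block, that is, _grad_sum_to_size(%229, %108). %108 in turn is [1,3,1].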

Winding it all forward again:

  • %226 will have the shape of %108, that is [1,3,1]
  • %301 will have one extra dimension appended at the end by the unsqueeze, so it becomes [1,3,1,1]
  • Finally in the last line, it seems we try to expand a Tensor of the shape [1,3,1,1] by an expansion list of [1,3,10].
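
For anyone who wants to replay the shape bookkeeping above without running the graph, here is a minimal sketch in plain Python. The helper names unsqueeze_shape and expand_shape are hypothetical (they are not part of PyTorch); they only mirror the documented shape rules of aten::unsqueeze and aten::expand as I understand them, and the input shapes are the ones read off the graph dump, taken at face value.

```python
from typing import List


def unsqueeze_shape(shape: List[int], dim: int) -> List[int]:
    """Shape after inserting a size-1 dimension at `dim` (hypothetical helper)."""
    rank = len(shape)
    # Valid range is [-(rank + 1), rank], as for torch.unsqueeze.
    if not -(rank + 1) <= dim <= rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    if dim < 0:
        dim += rank + 1
    return shape[:dim] + [1] + shape[dim:]


def expand_shape(shape: List[int], sizes: List[int]) -> List[int]:
    """Shape after expanding `shape` to `sizes` (hypothetical helper)."""
    if len(sizes) < len(shape):
        raise ValueError(
            f"{len(sizes)} sizes given for a tensor with {len(shape)} dimensions"
        )
    offset = len(sizes) - len(shape)  # number of new leading dimensions
    out = []
    for i, target in enumerate(sizes):
        current = shape[i - offset] if i >= offset else 1
        if target == -1:
            if i < offset:
                raise ValueError("-1 is not allowed for a new leading dimension")
            target = current  # -1 keeps the existing size
        elif current != 1 and current != target:
            raise ValueError(f"cannot expand size {current} to {target} at dim {i}")
        out.append(target)
    return out


# Values as read from the graph dump in this post.
list_114 = [-1]               # %76 <== %114
self_size = [1, 3, 10]        # %self_size.2 <== %self_size.3
size_108 = [1, 3, 1]          # %108 <== %307

dim_300 = list_114[0]                              # aten::select(%76, 0)
shape_226 = size_108                               # _grad_sum_to_size(%229, %108)
shape_301 = unsqueeze_shape(shape_226, dim_300)    # aten::unsqueeze(%226, %300)

try:
    shape_302 = expand_shape(shape_301, self_size)  # aten::expand(%301, %self_size.2)
    expand_error = None
except ValueError as err:
    shape_302 = None
    expand_error = str(err)

print("%301:", shape_301)
print("%302:", shape_302, "| error:", expand_error)
```

Under this reading the expand is rejected because the size list has three entries while %301 has four dimensions, whereas expand_shape([1, 3, 1], [1, 3, 10]) would be accepted. That only confirms the inconsistency; it does not say which of the assumed shapes is the wrong one.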

This doesn’t seem right; in fact, it blows up if I try to do this in Python.

So, can you help me figure out where my logic is incorrect? What should I look at, and what am I misunderstanding about these operations?

The full graphs and their connectivity vectors are here:


Hi Andras,

Thanks for asking here. I don’t know exactly what is wrong with the program you have. But if it already blows up in Python, you can use pdb to debug your Python (non-jit) program and inspect the intermediate steps. Put __import__('pdb').set_trace() somewhere in your program to see where it goes wrong, so that we have more context.

Thanks for the reply.

What I meant by ‘blowing up’ is that if I manually call ‘expand’ with a tensor of dimensions [1,3,1,1] and an expansion list of [1,3,10], I get a runtime error stating that the expansion list is shorter than the number of dimensions of the input tensor. Which of course makes complete sense, but for the life of me I can’t figure out how that wouldn’t be the case in this particular graph.

My original goal is not to execute the graph, merely to extract type information from it for each of the values, and that is of course where I run into trouble with this particular subsection of it.

Thanks again,