Hi all!

First of all, I’m new to this forum so excuse me if I’m not addressing the right audience.

I’m working on getting BERT forward and backprop traces from the PyTorch JIT, and I’ve run into a case where the tensor dimensions don’t line up with my expectations. I’ll attach the full graph dump, but it’s unwieldy, so let me summarize:

The forward graph has the following relevant lines:

```
%x.8 : Float(1, 3, 10) = aten::add(%output.6, %bias, %28)
%self_size.3 : int[] = aten::size(%x.8)
%57 : int = prim::Constant[value=-1](), scope: BERT/SublayerConnection
%114 : int[] = prim::ListConstruct(%57), scope: BERT/SublayerConnection/LayerNorm[norm]
%mean : Float(1, 3, 1) = aten::mean(%x.8, %114, %44), scope: BERT/SublayerConnection/LayerNorm[norm]
%307 : int[] = aten::size(%mean)
```
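To double-check my reading of the forward shapes, here is a small NumPy sketch (NumPy used purely as a stand-in for the aten ops; the shapes and the dim list `[-1]` are taken from the dump above, and I’m assuming `%44` is `keepdim=True`, judging by the `(1, 3, 1)` result):

```python
import numpy as np

# %x.8 : Float(1, 3, 10) = aten::add(%output.6, %bias, %28)
x = np.zeros((1, 3, 10)) + np.zeros((10,))  # broadcasted add -> (1, 3, 10)

# %mean = aten::mean(%x.8, [-1], keepdim=True) -> Float(1, 3, 1)
mean = x.mean(axis=-1, keepdims=True)

print(x.shape)     # (1, 3, 10)  == %self_size.3
print(mean.shape)  # (1, 3, 1)   == %307
```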

The backprop graph has three inputs that are important here, with the following connections:

```
%76 : int[], <== %114 (input[76] connected to output[40])
%self_size.2 : int[], <== %self_size.3 (input[106] connected to output[70])
%108 : int[], <== %307 (input[108] connected to output[72])
```

And these values propagate through the backprop graph as follows:

```
%130 : int = prim::Constant[value=0](), scope: BERT/BERTEmbedding[embedding]/TokenEmbedding[token]
%300 : int = aten::select(%76, %130)
%225 : Tensor, %226 : Tensor = prim::GradOf[name="aten::sub"](%224)
block0():
%227 : Tensor = aten::_grad_sum_to_size(%224, %self_size.2)
%228 : Tensor = aten::neg(%224)
%229 : Tensor = aten::mul(%228, %134), scope: BERT/SublayerConnection/LayerNorm[norm]
%230 : Tensor = aten::_grad_sum_to_size(%229, %108), scope: BERT/SublayerConnection/LayerNorm[norm]
-> (%227, %230)
%301 : Tensor = aten::unsqueeze(%226, %300)
%128 : bool = prim::Constant[value=0](), scope: BERT/BERTEmbedding[embedding]/TokenEmbedding[token]
%302 : Tensor = aten::expand(%301, %self_size.2, %128)
```

So, if I read this right:

- %128 is just a constant ‘false’, indicating that the expand is an explicit one.
- The second parameter (%self_size.2) is the list [1,3,10].
- The first parameter (%301) is the result of unsqueeze(%226, -1), where %226 is _grad_sum_to_size(%229, %108); %108 in turn is [1,3,1].
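As a sanity check on the unsqueeze step (again a NumPy stand-in; `np.expand_dims` plays the role of aten::unsqueeze here, which I believe has the same semantics for a negative dim):

```python
import numpy as np

# %226 has the shape of %108, i.e. (1, 3, 1)
g = np.zeros((1, 3, 1))

# %301 = aten::unsqueeze(%226, -1): insert a trailing dimension
g2 = np.expand_dims(g, -1)

print(g2.shape)  # (1, 3, 1, 1)
```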

Winding it all forward again:

- %226 will have the shape of %108, that is [1,3,1]
- %301 will have one extra dimension added to it to the end by the unsqueeze, so it becomes [1,3,1,1]
- Finally, in the last line, it seems we try to expand a tensor of shape **[1,3,1,1]** by an expansion list of **[1,3,10]**.

This doesn’t seem right; in fact, it blows up if I try to do it in Python.
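For what it’s worth, this is what I tried (a NumPy analogue, with `np.broadcast_to` standing in for aten::expand; I’m assuming the broadcasting rules match, so please correct me if that assumption is wrong):

```python
import numpy as np

grad = np.zeros((1, 3, 1))       # shape of %226, i.e. %108
grad = np.expand_dims(grad, -1)  # aten::unsqueeze(..., -1) -> (1, 3, 1, 1)

failed = False
try:
    # aten::expand with %self_size.2 = [1, 3, 10]: target has fewer
    # dimensions than the input, so broadcasting cannot apply
    np.broadcast_to(grad, (1, 3, 10))
except ValueError as e:
    failed = True
    print("expand fails:", e)
```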

So, can you help me figure out where my logic goes wrong? What should I look at, and what am I misunderstanding about these operations?

The full graphs and their connectivity vectors are here: http://www.modularcircuits.com/download/bert.log

Thanks,

Andras