Loops inside a model slow gradient calculation dramatically

I’m quite new to PyTorch and this is my first time posting here, so if this has already been addressed or is obvious, I apologize in advance. Here is the problem I’m having:

I have a neural network model that trains very quickly when vectorized. Here is the vectorized form:
yatomic=((torch.sigmoid(x.mm(torch.t(w1))+torch.t(b1.mm(M1)))-0.5).mm(torch.t(w2))+torch.t(b2.mm(M1)))
ypred=torch.mv(R,(yatomic[:,0]))
loss=(torch.mul((y[:,0]-ypred),Ninv[:,0])**2.0).sum() #MSE of model

This trains in about 5 minutes and converges very nicely.

However, if I instead compute the mean squared error in the loss term using for loops (rather than the R matrix above, which is an alternative to the loops), the gradient calculation takes around 25x longer. For example, here is the same model with loops:

yatomic=(torch.sigmoid(x.mm(torch.t(w1))+torch.t(b1.mm(M1)))-0.5).mm(torch.t(w2))+torch.t(b2.mm(M1))
k1=0; k2=0; S1=0.0
for i in range(0,len(sample)): #sum over atomic structures
    i1=sample[i] #id dictionary
    S=0.0
    for j in atoms[i1]: #sum over atoms in structure
        S=S+yatomic[k1,0]
        k1=k1+1
    S1=S1+((S-y[i,0])*Ninv[i,0])**2.0 #compare with DFT energy
loss=S1

The model still works and outputs the correct values, but it is much slower. Am I doing something wrong or stupid here that is slowing everything down, or is PyTorch just much faster at calculating the gradient for matrix-type operations, so that the slowdown is unavoidable when using loops?

I know that the obvious answer is “why are you using the loops, just use the vectorized form”, but the neural network is just the first part of the “full” model. I need to use the output of the neural network as an input to an analytical physical model, which is very difficult to vectorize. At least at present, it requires some loops and can’t be expressed as a series of matrix-vector operations.

Also, I should note that the loops themselves are not the bottleneck: they complete almost instantaneously. It is the gradient calculation that is slowing things down.

Any help would be appreciated!

Hi,

The problem with for loops is threefold:

  • Running the Python code for the for loops can be slow, and running many very small ops is usually less efficient than running a few big ones (this slows down both the forward and backward pass).
  • The indexing ops are not free: each one is a separate op that is also recorded for the backward pass.
  • The number of operations the loops add to the computational graph is huge, so anything that has to traverse the graph, such as the backward pass, becomes very expensive.
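To make the points above concrete, here is a minimal sketch of the loss from the question written both ways: once with the nested Python loops, and once with a single `index_add` that sums atomic outputs per structure. The data is made up for illustration (`atoms_per_structure`, the tensor sizes, and the function names `loss_loop` and `loss_segment_sum` are assumptions, not the original poster's code), and it only covers the summation step, not the physical model mentioned in the question.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the thread's data: 3 structures with 2, 3 and 1 atoms.
atoms_per_structure = [2, 3, 1]
n_structures = len(atoms_per_structure)
n_atoms = sum(atoms_per_structure)

yatomic = torch.randn(n_atoms, 1, requires_grad=True)  # per-atom network outputs
y = torch.randn(n_structures, 1)                       # reference energies
Ninv = torch.tensor([[1.0 / n] for n in atoms_per_structure])


def loss_loop(yatomic):
    """Loss built with Python loops: one tiny indexing op and add per atom."""
    k1 = 0
    total = 0.0
    for i, n in enumerate(atoms_per_structure):
        S = 0.0
        for _ in range(n):
            S = S + yatomic[k1, 0]
            k1 += 1
        total = total + ((S - y[i, 0]) * Ninv[i, 0]) ** 2.0
    return total


def loss_segment_sum(yatomic):
    """Same loss, but the per-structure sums come from one index_add call."""
    # structure_id[k] is the index of the structure that atom k belongs to.
    structure_id = torch.repeat_interleave(
        torch.arange(n_structures), torch.tensor(atoms_per_structure)
    )
    ypred = torch.zeros(n_structures).index_add(0, structure_id, yatomic[:, 0])
    return (((ypred - y[:, 0]) * Ninv[:, 0]) ** 2.0).sum()


# Both versions should agree on the loss and on its gradient.
grad_loop, = torch.autograd.grad(loss_loop(yatomic), yatomic)
grad_vec, = torch.autograd.grad(loss_segment_sum(yatomic), yatomic)
print(torch.allclose(grad_loop, grad_vec, atol=1e-6))
```

The loop version records a handful of autograd nodes per atom, while the `index_add` version records a fixed number of nodes regardless of how many atoms there are, which is the difference that the backward pass has to pay for.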

Hi, I’m confused by the claim that “for loops slow down the backward pass”. I thought backpropagation speed depends only on the constructed computational graph, and therefore has nothing to do with how the forward pass is implemented.

But the computational graph matches the forward pass. So if you do a lot of ops in the forward pass, the graph will be large and the backward pass slower.

Ah, I see. Taking matrix multiplication as an example, each loop iteration would create new nodes in the computational graph, while torch.matmul would create only a single op?
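One rough way to check this is to count the autograd nodes reachable from the result. The sketch below does that for a matrix-vector product written with loops versus `torch.mv`. The helper `count_graph_nodes` is a hypothetical name introduced here (not a PyTorch API); it simply walks `grad_fn.next_functions`, and the tensor sizes are arbitrary.

```python
import torch


def count_graph_nodes(output):
    """Count distinct autograd nodes reachable from output.grad_fn."""
    seen = set()
    stack = [output.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions holds (node, input_index) pairs for this op's inputs.
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)


torch.manual_seed(0)
A = torch.randn(5, 4, requires_grad=True)
v = torch.randn(4, requires_grad=True)

# Matrix-vector product written with Python loops: several small ops per element.
rows = []
for i in range(A.shape[0]):
    acc = 0.0
    for k in range(A.shape[1]):
        acc = acc + A[i, k] * v[k]
    rows.append(acc)
out_loop = torch.stack(rows)

# The same product as a single library call.
out_mv = torch.mv(A, v)

n_loop = count_graph_nodes(out_loop)
n_mv = count_graph_nodes(out_mv)
print("nodes with loops:", n_loop, "| nodes with torch.mv:", n_mv)
```

The exact counts depend on the PyTorch version, but the loop version grows with the number of matrix elements (indexing, multiply, and add nodes for each one), while the `torch.mv` version stays at a few nodes no matter how large the matrix is.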