Training slows down with a loop like x = f(x)

Here x is a tensor and f is an nn.Module model. Both are on CUDA:

x = x.to(torch.device('cuda'))
f = f.to(torch.device('cuda'))

The main loop looks like this:

while (some condition):
    x = f(x)

I noticed that training slows down pretty quickly. I believe this feedforward mechanism could be the problem. Do I need to use a different approach, like creating an intermediate tensor for x and feeding it into f?

The problem may be that, by doing this, you keep growing the computational graph (if you never call backward). I would say it depends on how many iterations your loop runs before the condition breaks it; I would check that with a counter. Note that if you feed forward many times, your graph will be huge and backpropagation will be much slower.

c = -1
while (some condition):
    c += 1
    x = f(x)
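
For instance, a quick sketch like this (the Linear model and the sizes are made up, purely to illustrate) shows what happens: since backward() is never called, every x = f(x) extends the same autograd graph, all intermediate activations stay alive, and allocated memory keeps climbing.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
f = torch.nn.Linear(1024, 1024).to(device)     # dummy stand-in for your model
x = torch.randn(256, 1024, device=device)

for i in range(100):
    x = f(x)                                   # appends another node to the same graph
    if device == 'cuda' and i % 25 == 0:
        # grows because every intermediate activation is kept for a future backward
        print(i, torch.cuda.memory_allocated() // 2**20, 'MiB allocated')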

OK, then maybe I don't understand how PyTorch works well enough. When I set
x = f(x)
does the 'new' x replace the 'old' x in the graph?

Nope, you are creating a junction between both forward passes, in a siamese-like fashion. All your runs will have an effect on the backpropagation.
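
You can check it with a tiny sketch (a single Linear layer standing in for f, three chained passes, all made up): one backward() call propagates through every forward pass in the chain, so the weight receives gradient contributions from all of them.

import torch

f = torch.nn.Linear(4, 4)          # stand-in for your model
x = torch.randn(1, 4)

for _ in range(3):
    x = f(x)                       # each pass is appended to the same graph

x.sum().backward()                 # flows back through all three passes
print(f.weight.grad.norm())        # gradient accumulated from every pass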

Got it, so how should I do it instead? Something like


y = f(x)
x = y

If I understand correctly, you want to run the network until it meets a condition, but backpropagate only through the results that meet the condition?

I don't know if there is an elegant way to do that.
The only fast way I can think of would be the following. This approach bypasses graph construction and would therefore be faster.

x_old = ...                     # initial input
while True:
    with torch.no_grad():       # forward passes without building a graph
        x = f(x_old)
    if condition:
        x = f(x_old)            # redo only the final pass with grad enabled
        break
    else:
        x_old = x
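
As a concrete, runnable version (the Linear model, the random input and the fixed iteration budget used as the condition are all placeholders I made up):

import torch

f = torch.nn.Linear(8, 8)          # placeholder model
x_old = torch.randn(1, 8)          # placeholder initial input
i, max_iters = 0, 50               # hypothetical condition: fixed iteration budget

while True:
    with torch.no_grad():          # these passes build no graph
        x = f(x_old)
    i += 1
    if i >= max_iters:             # condition met: redo only the last pass with grad
        x = f(x_old)
        break
    x_old = x

x.sum().backward()                 # backprop goes through that single pass only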

Another way would be to cut the computational graph. This would probably use more memory and be slower, since each forward pass still builds a graph that only goes away when you detach (but it wouldn't affect backprop).

while True:
    x = f(x)
    if condition:
        break
    else:
        x = x.detach()          # drop the graph built so far; x becomes a new leaf

There is a third option, which would be deleting the graph, but I don't really know if that's possible.

Anyway, this pipeline looks a bit strange to me.

The condition just loops until a predefined number of iterations is reached. Apart from that, it seems you did what I thought was right:

x = f(x_old)
x_old = x

The idea is to get the network’s output, compute some loss function on it, backprop into the network and then take this output as the network’s input in the next iteration. Does this make sense?

Well, it’s not the same at all.
So that would be

while (some condition):
    x = f(x)
    loss = function(x)
    loss.backward()
    ...
    x = x.detach()              # cut the graph so the next iteration starts fresh

The last part converts x into a leaf variable, and it should work.
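
Putting it together, a self-contained sketch could look like this (the Linear model, the loss, the SGD optimizer and the fixed number of iterations are all assumptions on my side, not your actual setup):

import torch

f = torch.nn.Linear(16, 16)                           # placeholder model
x = torch.randn(4, 16)                                # placeholder initial input
optimizer = torch.optim.SGD(f.parameters(), lr=1e-3)  # placeholder optimizer

for _ in range(100):                                  # 'some condition': fixed budget here
    x = f(x)
    loss = x.pow(2).mean()                            # placeholder loss on the output
    optimizer.zero_grad()
    loss.backward()                                   # graph covers only this iteration's pass
    optimizer.step()
    x = x.detach()                                    # make x a leaf again for the next pass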

Thanks. So x=detach(x) creates a new variable x and adds it to the graph. What happens to the old x then?

Backpropagation runs from the variable you call backward() on down to the leaf variables.

Leaf variables are the input nodes of the graph; intermediate variables keep the graph history.
When you call detach() you are essentially breaking the graph and creating a fresh leaf variable, so that memory can be reused. Otherwise it would still carry the old graph, I would say.
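
For example (tiny made-up tensors):

import torch

w = torch.randn(3, requires_grad=True)   # a leaf: created by the user
y = w * 2                                # intermediate: carries graph history
print(w.is_leaf, y.is_leaf)              # True False
print(y.grad_fn)                         # <MulBackward0 object ...>

z = y.detach()                           # breaks the graph
print(z.is_leaf, z.grad_fn)              # True None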

And it's not x = detach(x) but x = x.detach() :slight_smile: detach() is a tensor method. I dunno if it exists as a standalone function in torch.

Yeah, sorry, my mistake. So x.detach() removes the 'old' x from the graph and inserts the new one, right? Other (inner) tensors stay the same.

yep
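
For what it's worth, a quick check (made-up example) shows what detach() gives you: a new tensor with no grad_fn that still shares the same underlying storage.

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2
c = b.detach()

print(b.grad_fn, c.grad_fn)              # <MulBackward0 ...> None
print(c.data_ptr() == b.data_ptr())      # True: same underlying storage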