Here x is a tensor and f is an nn.Module model. Both are on CUDA:
x = x.to(torch.device('cuda'))
f = f.to(torch.device('cuda'))
The main loop looks like this:
while some_condition:
    x = f(x)
I noticed that training slows down pretty quickly. I believe this feedforward mechanism could be the problem. Do I need to use a different approach, like creating an intermediary tensor for x and feeding it into f?
The problem may be that, by doing that, you keep growing the computational graph (if you never call backward). I would say it depends on how many iterations your loop runs before the condition breaks it; I would check that with a counter. Note that if you feed forward lots of times, your graph will be huge and will therefore backpropagate much more slowly.
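As a rough illustration of that growth, here is a minimal sketch with a made-up f and x (none of these names come from your post), assuming a CUDA device is available; the allocated memory keeps climbing because every iteration's activations stay referenced by the graph, waiting for a backward call that never happens:

import torch
import torch.nn as nn

f = nn.Linear(4096, 4096).to('cuda')        # hypothetical model
x = torch.randn(64, 4096, device='cuda')    # hypothetical input

for step in range(100):                     # stand-in for "while(some condition)"
    x = f(x)
    if step % 20 == 0:
        # every previous iteration's tensors are still alive through x's graph
        print(step, torch.cuda.memory_allocated() // 2**20, "MiB allocated")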
If I understand you properly, you want to run the network until it meets a condition, but backpropagate only through the results that meet the condition?
I don’t know if there is an elegant way to do that.
The only fast way I can think of would be the following. This approach bypasses the graph construction during the iterations, so it would be faster:
x_old = ...                        # initial input
while True:
    with torch.no_grad():          # no graph is built for these iterations
        x = f(x_old)
    if condition:
        x = f(x_old)               # redo the last forward with grad enabled so backward works
        break
    else:
        x_old = x
Another way would be cutting the computational graph with detach. This would probably use more memory and be slower, as the whole graph would stay around (but it wouldn't affect backprop):
while True:
    x = f(x)
    if condition:
        break
    else:
        x = x.detach()             # cut the graph so it does not keep growing
There is a third option, which would be deleting the graph, but I don't really know if that's possible.
The condition is just to loop until a predefined number of iterations is reached. Apart from that, it seems you wrote what I had in mind:
x = f(x_old)
x_old = x
The idea is to get the network's output, compute some loss function on it, backpropagate into the network, and then take this output as the network's input in the next iteration. Does this make sense?
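A minimal sketch of that pattern, with a hypothetical model, loss, target and optimizer (none of these names come from your post); the key point is detaching the output before feeding it back in, so each backward pass only covers one forward step:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

f = nn.Linear(10, 10).to(device)              # hypothetical model
optimizer = torch.optim.SGD(f.parameters(), lr=1e-3)
criterion = nn.MSELoss()
target = torch.zeros(1, 10, device=device)    # hypothetical target

x = torch.randn(1, 10, device=device)
num_iterations = 100                          # the "predefined iteration number"

for step in range(num_iterations):
    y = f(x)                      # forward pass builds a graph for this step only
    loss = criterion(y, target)
    optimizer.zero_grad()
    loss.backward()               # backprop through this single step
    optimizer.step()
    x = y.detach()                # reuse the output as input without carrying the old graph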
Backpropagation runs from the variable you call backward on down to the leaf variables.
Leaf variables are the input nodes of the graph; intermediate variables keep graph history.
When you call detach you are essentially breaking the graph and creating a fresh leaf variable, so that memory can be reused. Otherwise, it still carries the old graph, I would say.
And it's not x = detach(x) but x = x.detach(); detach is a tensor method. I don't know if it exists as a standalone function in torch.
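A small sketch of that behaviour (the tensor names are just for illustration):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()                 # intermediate node: keeps graph history
print(x.is_leaf, y.is_leaf)       # True False
print(y.grad_fn)                  # <SumBackward0 object ...>

z = y.detach()                    # new tensor cut from the graph
print(z.is_leaf, z.grad_fn)       # True None
print(z.requires_grad)            # False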