PyTorch getting slower with every iteration

I’m using PyTorch for some optimization problems, and it has been a nice tool. Unfortunately, the script I wrote produces the results I want, but for some reason it gets slower with every iteration. I tried to reduce the code to a minimal implementation, but so far I have not managed to pinpoint the exact operation that causes the problem.

Here is the minimized script with timing code included: https://pastebin.com/2pvNUQUg

Could someone please verify the results? (The problem shows up on both CPU and GPU.) I have not noticed any change in memory usage. As can be seen, every block of 100 iterations takes around 0.5% more time to complete than the previous one:

$ python3 pytorch_slowdown.py
3.5418 +254.2%
3.5648 +0.7%
3.5748 +0.3%
3.5940 +0.5%
3.6103 +0.5%
3.6276 +0.5%
3.6512 +0.7%
3.6621 +0.3%
3.6814 +0.5%
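
If you can’t open the pastebin, the loop has essentially this shape (simplified much further, with stand-in sizes and a stand-in update, so not the exact script):

import time
import torch

N = 1000
x = torch.randn(N, requires_grad=True)
z = torch.zeros(N)

prev = None
start = time.perf_counter()
for i in range(1, 901):
    z = z + x                # z keeps a reference to every previous step
    loss = z.sum()
    loss.backward()          # walks the whole accumulated history
    x.grad = None            # clear the gradient between iterations
    if i % 100 == 0:
        elapsed = time.perf_counter() - start
        pct = "" if prev is None else f" {100 * (elapsed - prev) / prev:+.1f}%"
        print(f"{elapsed:.4f}{pct}")
        prev = elapsed
        start = time.perf_counter()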

So when I time each line individually (https://pastebin.com/Q1nXMStH), it shows that the slowdown happens in the loss.backward() call. I guess one way that function could slow down is if the graph “grows” with each iteration. How can I inspect the computation graph that is built there? Any idea how to fix the code? Thanks!

loss <- c <- b ---------\
          <- z <- b ----> x
               <- z_prev <- ...

So your loss depends on all the z’s, which depend on all the b’s. Each new iteration then adds another (z, b) pair to the graph; hence the slowdown.
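
You can verify this by counting the autograd nodes reachable from loss.grad_fn. Here is a quick sketch that walks the graph through the next_functions attribute of the autograd nodes:

def count_graph_nodes(t):
    # Walk backward from t.grad_fn and count the distinct autograd nodes.
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            stack.append(next_fn)
    return len(seen)

Print count_graph_nodes(loss) once per iteration: if the count climbs by a constant amount every time, the graph is growing and backward() has more work to do at each step.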

How to fix your script depends on where you want the gradient to stop flowing back. E.g., if you want it to stop going to z_prev (and therefore all previous z’s and b’s), do z = updateZ(z.detach(), b.view(N, -1, xdim[1])) instead.
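
Applied to a toy loop (generic names, not the ones from your script), the idea looks like this:

import torch

x = torch.randn(4, requires_grad=True)
z = torch.zeros(4)
for _ in range(1000):
    z = z.detach() + x   # detach(): gradients stop here, so every iteration
    loss = z.sum()       # builds a fresh, constant-size graph
    loss.backward()      # constant work per step, no slowdown
    x.grad = None        # clear the accumulated gradient between steps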

Thank you, Simon! Below is the fixed version of the minimal problem (still reduced from the one in the original post). Your fix also worked on my original code, so all in all this is resolved.

https://pastebin.com/fFu4fkAY