How to avoid for loops in pytorch?

I’m working on a task where the input is a sequence of images (5 in my case) and the output is a set of sentences, one per image. I’m using an encoder-decoder architecture in which the sentences already generated are fed back as input for generating the next one, which forces me to loop over the sequence (a for loop) inside the forward pass. Roughly, the forward pass looks like the sketch below.
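This is only a simplified sketch of the setup; the encoder, decoder, and their interface (the decoder taking the previous context and returning a sentence plus an updated context) are placeholders, not my exact code:

```python
import torch
import torch.nn as nn

class CaptionModel(nn.Module):
    def __init__(self, encoder, decoder, hidden_size):
        super().__init__()
        self.encoder = encoder        # image encoder (placeholder)
        self.decoder = decoder        # sentence decoder (placeholder)
        self.hidden_size = hidden_size

    def forward(self, images):
        # images: (batch, 5, C, H, W) -- a sequence of 5 images per sample
        outputs = []
        context = torch.zeros(images.size(0), self.hidden_size, device=images.device)
        for t in range(images.size(1)):           # loop over the 5 images
            feats = self.encoder(images[:, t])    # encode image t
            # use the previously generated sentences (via context) as input
            sentence, context = self.decoder(feats, context)
            outputs.append(sentence)
        return outputs
```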
As already discussed here, each iteration of the loop extends the computation graph, therefore increasing memory utilization.
Is there any way to get around the need to use the loop? Or alternatively, is there an explicit way to avoid duplicating the graph?

Thanks

You could call .detach() on the tensors you feed back into the next iteration, so the graph built in previous steps is not kept alive.
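A minimal sketch, reusing the hypothetical names from the question’s setup (encoder, decoder, images, context, hidden_size): detaching the tensor carried into the next iteration cuts the autograd graph between steps, so memory does not keep growing across all 5 images. The trade-off is that gradients no longer flow back through earlier iterations (truncated backpropagation).

```python
import torch

outputs = []
context = torch.zeros(images.size(0), hidden_size, device=images.device)
for t in range(images.size(1)):
    feats = encoder(images[:, t])                # encode image t
    sentence, context = decoder(feats, context)  # generate sentence t
    outputs.append(sentence)
    context = context.detach()                   # drop the graph of earlier steps
```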