I am working on something similar to L2L GDGD (Learning to Learn by Gradient Descent by Gradient Descent). I have a very small network that is applied element-wise over a tensor; the per-element results are stacked and returned.
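For context, here is a rough sketch of that kind of setup. The architecture (a hypothetical TinyNet mapping one scalar to one scalar through a 40-unit hidden layer, which happens to give 121 parameters) and the helper name apply_elementwise are assumptions for illustration, not the original poster's code. The point is that the network is called once per element and the outputs are stacked back into the input's shape, so the result carries a grad_fn.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical tiny scalar-to-scalar network (1 -> 40 -> 1, 121 params)."""

    def __init__(self, hidden: int = 40):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


def apply_elementwise(net: nn.Module, t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: run the network on each element separately,
    # then stack the scalar outputs and restore the input's shape.
    outs = [net(v.reshape(1, 1)).reshape(()) for v in t.reshape(-1)]
    return torch.stack(outs).reshape(t.shape)


if __name__ == "__main__":
    torch.manual_seed(0)
    net = TinyNet()
    x = torch.randn(3, 4)
    y = apply_elementwise(net, x)
    print(sum(p.numel() for p in net.parameters()))
    print(y.shape, y.grad_fn is not None)
```

Each element is an independent forward call, which is exactly the loop the question wants to parallelise.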
Because this can be parallelised, my first attempt was to spawn processes with torch.multiprocessing. However, I noticed that sending the result tensors back through a queue or pipe wipes out their grad_fn attribute. If instead I pass a tensor in via Process(args=()), grad_fn is preserved. A minimal working gist of a slightly different problem that shows the issue is attached below.
If this behaviour is intended, what are my options for speeding up my model? I only have a single GPU, and the network itself has only around 120 parameters, so I don't think DataParallel can help much.
I have exactly the same problem: I'm using processes and gathering the tensor results through a queue. When I call backward() at the end of the computation, the gradients with respect to the variables used in each process (which are defined outside the processes) are None, meaning that PyTorch did not see that these variables were involved in the gradient computation.
On my side, I noticed that I can compute the gradient in each separate process and return it together with the result of the forward pass, but it would be nicer to be able to call .backward() in the main process once all the computations are done.
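The "return the gradients along with the outputs" workaround can be sketched like this. Assumptions: the total loss is a sum of per-chunk losses (here a hypothetical out.pow(2).sum()), and worker_forward_backward and combine are hypothetical helper names. worker_forward_backward is what each worker process would run: it uses torch.autograd.grad to get the chunk's parameter gradients and returns only detached tensors, which is all a queue can carry anyway. combine then accumulates those gradients into .grad in the main process. To keep the sketch self-contained, the workers are replaced by a plain loop; in a real setup each call would run in its own process (for example with the module's parameters shared via net.share_memory()) and send its tuple back through the queue.

```python
import torch
import torch.nn as nn


def worker_forward_backward(net: nn.Module, chunk: torch.Tensor):
    """Hypothetical per-worker step: forward + local gradients, no graph returned."""
    out = net(chunk.reshape(-1, 1)).reshape(chunk.shape)
    # Assumption: the global loss is a sum of these per-chunk losses.
    partial_loss = out.pow(2).sum()
    grads = torch.autograd.grad(partial_loss, list(net.parameters()))
    # Detached tensors only: this is what would travel through a queue/pipe.
    return out.detach(), [g.detach() for g in grads]


def combine(net: nn.Module, results):
    """Hypothetical main-process step: stitch outputs, sum gradients into .grad."""
    for p in net.parameters():
        p.grad = torch.zeros_like(p)
    for _, grads in results:
        for p, g in zip(net.parameters(), grads):
            p.grad += g
    return torch.cat([out for out, _ in results])


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(1, 40), nn.Tanh(), nn.Linear(40, 1))
    x = torch.randn(8)
    # Plain loop standing in for worker processes.
    results = [worker_forward_backward(net, c) for c in x.chunk(4)]
    y = combine(net, results)
    print(y.shape, [p.grad.shape for p in net.parameters()])
```

Because gradients are linear in the loss, summing the per-chunk gradients gives the same .grad as a single backward() over the whole loss; an optimizer step can then run in the main process as usual.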
The solution I would suggest is exactly what @Maxime_Louis said: when returning tensors from subgraphs, return the gradients too. There's no way to make autograd work across processes without terrible hacks and a performance penalty, because processes don't share the same address space.
If your network is that tiny, I'd just use the CPU. I would also try to write batched code for the network, so that instead of parallelizing over elements you take the input as a single large batch of independent values and return a batch of results. If you don't see how to do that, try threads instead, because work done in threads is visible to autograd.
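Both suggestions can be sketched as follows, assuming a scalar-to-scalar network like the one in the question (the 1 -> 40 -> 1 nn.Sequential below is a hypothetical stand-in, and apply_batched / apply_threaded are hypothetical names). apply_batched treats every element as an independent sample of shape (N, 1), so the whole tensor goes through one forward call. apply_threaded is the fallback: threads share the address space, so the outputs keep their grad_fn and backward() works in the caller. Whether threads actually speed anything up for ops this small is not guaranteed, since Python overhead and the GIL may dominate; batching is usually the bigger win.

```python
import threading

import torch
import torch.nn as nn


def apply_batched(net: nn.Module, t: torch.Tensor) -> torch.Tensor:
    # Every element becomes one row of an (N, 1) batch: one forward call total.
    return net(t.reshape(-1, 1)).reshape(t.shape)


def apply_threaded(net: nn.Module, t: torch.Tensor, num_threads: int = 4) -> torch.Tensor:
    # Fallback: split the elements across threads. Same address space,
    # so the autograd graph built in each thread stays connected.
    chunks = t.reshape(-1).chunk(num_threads)
    results = [None] * len(chunks)

    def work(i: int, c: torch.Tensor) -> None:
        results[i] = torch.stack([net(v.reshape(1, 1)).reshape(()) for v in c])

    threads = [threading.Thread(target=work, args=(i, c)) for i, c in enumerate(chunks)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return torch.cat(results).reshape(t.shape)


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(1, 40), nn.Tanh(), nn.Linear(40, 1))
    x = torch.randn(3, 5)
    yb = apply_batched(net, x)
    yt = apply_threaded(net, x)
    yt.sum().backward()  # gradients reach net's parameters from the threads
    print(torch.allclose(yb, yt, atol=1e-5), net[0].weight.grad is not None)
```

The batched version also removes the per-element Python loop entirely, which for a ~120-parameter network is likely where most of the time goes.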