How to attach a tensor to a particular point in the computation graph?

As the title says, I need to attach a tensor to a particular point in the computation graph in PyTorch.

What I’m trying to do is this: while getting the outputs of all mini-batches, accumulate them in a list and, when the epoch finishes, calculate their mean. Then I need to calculate the loss from that mean, so backpropagation must go through all of these operations.
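A minimal sketch of that approach, on a hypothetical toy setup (a small `nn.Linear` model, random data split into four mini-batches, and a squared-distance-to-zero loss standing in for whatever the real loss is). Keeping the outputs un-detached is what lets `backward()` reach every mini-batch, and it is also why memory grows with the number of batches:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a linear model and random data split into mini-batches.
model = nn.Linear(4, 3)
data = torch.randn(32, 4)
batches = data.split(8)  # 4 mini-batches of 8 samples each

# Hypothetical target for the epoch-level mean (stands in for the real loss target).
target_mean = torch.zeros(3)

outputs = []
for x in batches:
    out = model(x)        # NOT detached: each output keeps its whole graph alive
    outputs.append(out)

# Mean over every sample seen in the epoch.
epoch_mean = torch.cat(outputs).mean(dim=0)
loss = ((epoch_mean - target_mean) ** 2).sum()

model.zero_grad()
loss.backward()  # gradients flow back through every mini-batch's forward pass
```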

I can do this when the training data is small (storing the outputs without detaching them). However, it stops working once the data gets bigger. If I don’t detach the output tensors each time, I run out of GPU memory; if I do detach them, they are cut out of the computation graph and the gradient can no longer flow back through them. Adding GPUs doesn’t seem to help either: even when I assign more than 4 GPUs, PyTorch only uses the first 4 to store the output tensors if I don’t detach them before saving them into the list.
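One possible way around this, assuming the loss depends on the outputs only through their epoch-level mean, is a two-pass recomputation (a different technique from storing the outputs, not something PyTorch does for you). Because the mean is linear in the outputs, dL/dout_i is simply dL/dmean divided by N for every sample. So you can (1) compute the mean under `torch.no_grad()`, (2) get dL/dmean from a tiny graph with `torch.autograd.grad`, and (3) rerun each mini-batch with autograd enabled and call `out.backward(gradient=...)` on it, so only one batch's graph is alive at a time while `.grad` accumulates. The function name `epoch_mean_backward` and the toy setup are hypothetical, and the sketch assumes the forward pass is deterministic (no dropout, BatchNorm in eval mode); otherwise the second pass would not reproduce the first:

```python
import torch
import torch.nn as nn

def epoch_mean_backward(model, batches, loss_fn):
    """Hypothetical helper: backpropagate loss_fn(mean of all outputs) while
    keeping only one mini-batch graph in memory at a time.

    Assumes a deterministic forward pass so pass 2 reproduces pass 1.
    """
    # Pass 1: epoch-level mean, no graph is built.
    with torch.no_grad():
        total, n = None, 0
        for x in batches:
            out = model(x)
            s = out.sum(dim=0)
            total = s if total is None else total + s
            n += out.shape[0]
        mean = total / n

    # dL/dmean from a tiny graph that only contains the loss.
    mean = mean.detach().requires_grad_(True)
    loss = loss_fn(mean)
    (grad_mean,) = torch.autograd.grad(loss, mean)

    # Pass 2: mean = (1/n) * sum_i out_i, so dL/dout_i = grad_mean / n.
    # Backward one batch at a time; each batch's graph is freed afterwards
    # and the parameter gradients accumulate in .grad.
    model.zero_grad()
    for x in batches:
        out = model(x)
        out.backward(gradient=grad_mean.expand_as(out) / n)
    return loss.detach()

# Hypothetical toy usage, same setup as a single-pass version would use.
torch.manual_seed(0)
model = nn.Linear(4, 3)
data = torch.randn(32, 4)
batches = data.split(8)
target_mean = torch.zeros(3)

loss = epoch_mean_backward(model, batches, lambda m: ((m - target_mean) ** 2).sum())
# model.weight.grad / model.bias.grad now hold the full-epoch gradient,
# so an optimizer.step() could follow here.
```

The price is a second forward pass over the epoch, but peak memory is roughly that of a single mini-batch graph rather than all of them. Batch order can differ between the two passes (e.g. a shuffling DataLoader), since the mean does not depend on order.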

Any help is really appreciated.

Thanks.

Have you found a solution to this?