Avoiding retain_graph=True with a memory bank

Hi there, I’m trying to reproduce a training procedure with a “memory bank”: at each training step, a fixed number of feature embeddings is dequeued from a fixed-size list and used in the loss computation, and the feature embeddings of the current batch are then pushed onto the back of the list. The problem is that if I don’t detach the feature embeddings before pushing them into the memory bank, I get the error “Trying to backward through the graph a second time”, which of course makes sense, since those features belong to the graphs of earlier iterations. Is there any way to do this without setting retain_graph=True, which runs out of memory very quickly, or detaching the samples in the memory bank, which defeats the point of training this way?
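
For reference, here is a minimal sketch of the update step I have in mind (simplified, and all names are placeholders rather than my actual code):

```python
import torch
from collections import deque

# Fixed-size memory bank; the oldest embeddings fall off the front automatically.
memory_bank = deque(maxlen=4096)

def training_step(model, batch, criterion, optimizer):
    feats = model(batch)                        # embeddings of the current batch

    if memory_bank:                             # skip the loss until the bank has entries
        bank = torch.stack(list(memory_bank))   # embeddings stored on earlier steps
        loss = criterion(feats, bank)
        optimizer.zero_grad()
        loss.backward()   # fails here on later steps, because the bank entries
                          # still reference graphs that have already been freed
        optimizer.step()

    # Pushing the raw (non-detached) features is what keeps the old graphs alive;
    # calling .detach() here would avoid the error, but it stops gradients from
    # flowing through the bank.
    memory_bank.extend(feats)
```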

Thank you!


What’s the algorithm you’re attempting to reproduce?

Would checkpointing be useful in your situation?
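
By checkpointing I mean torch.utils.checkpoint, roughly along these lines (just a sketch; `model.blocks` is a placeholder for however your layers happen to be grouped):

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(model, x):
    # Activations inside each block are not stored during the forward pass;
    # they are recomputed during backward, trading compute for memory.
    for block in model.blocks:
        x = checkpoint(block, x)
    return x
```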

Edit: actually, I think you might be doing something conceptually wrong here if you need to backward twice through the graph.

Any progress on this? I am running into the same issue. I am trying to dynamically increase the number of negative examples for an NCE loss by using a memory bank that collects latent embeddings of past samples, so that I can reduce the mini-batch size that is passed through the model.
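
For concreteness, this is roughly the pattern I mean (a simplified sketch; the InfoNCE-style loss and all names are placeholders, not my actual code):

```python
import torch
import torch.nn.functional as F
from collections import deque

memory_bank = deque(maxlen=8192)  # stores latent embeddings as extra negatives

def nce_step(encoder, anchors, positives, temperature=0.07):
    z_a = F.normalize(encoder(anchors), dim=1)
    z_p = F.normalize(encoder(positives), dim=1)

    # Candidates: in-batch embeddings plus everything currently in the bank.
    if memory_bank:
        candidates = torch.cat([z_p, torch.stack(list(memory_bank))])
    else:
        candidates = z_p

    logits = z_a @ candidates.t() / temperature
    labels = torch.arange(z_a.size(0), device=z_a.device)  # positive for anchor i is column i
    loss = F.cross_entropy(logits, labels)

    # Same dilemma as above: without .detach() the bank entries keep the old
    # graphs alive and backward() fails on later iterations.
    memory_bank.extend(z_p)
    return loss
```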