Computing network embeddings: avoiding a memory leak

I am dealing with variable-sized inputs, and most instances are too large to fit in memory, even across multiple GPUs. I therefore use one network to compute a reduced representation (embedding), which I then forward to the prediction head (basically a form of gradient checkpointing).

There are surely many ways to implement this.
I am particularly surprised by a memory leak I encountered in the following approach:
        with torch.no_grad():
            instances = defaultdict(list)
            for batch in dataloader:
                img = batch["img"].to(self.device)
                # embedding network name assumed; the original line was cut off
                batch["embedding"] = self.embedding_net(img)
                # move every tensor back to the CPU ("patchbatch" in the
                # original appears to be a typo for "batch")
                for key in batch:
                    batch[key] = batch[key].cpu()

I believe `.cpu()` is not an in-place operation and results in a copy. So when I later collect batch["img"] and batch["embedding"] for a later backward pass, they do not live on the GPU.
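To illustrate that copy semantics (a minimal CPU-only sketch; a dtype change stands in for the device move, since `.to(device)` and `.cpu()` behave the same way in this respect): the call returns a new tensor, and the original's storage stays alive as long as any Python name still references it:

```python
import torch

t = torch.ones(3)        # stands in for a tensor living on the GPU
u = t.to(torch.float64)  # .to()/.cpu() return a *new* tensor (a copy)

# The two tensors use separate storage ...
print(u.data_ptr() != t.data_ptr())  # True

# ... so the original's memory is only released once every reference
# to it (the name `t`, but also closures, dict entries, or an autograd
# graph) is gone.
del t
```

So keeping `img` (or the pre-`.cpu()` values inside `batch`) alive anywhere would be enough to pin the GPU allocation.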

At which point am I blocking the GPU's memory from being freed?