Computing network embeddings: avoiding a memory leak

I am dealing with variable-sized inputs, and most instances are too large to fit in GPU memory, even when using multiple GPUs. I therefore use one network to compute a reduced representation (an embedding), which I then forward to a prediction head (essentially a form of gradient checkpointing).
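For context, here is a minimal sketch of the two-stage setup I mean. The module names, layer sizes, and the patch-based reading of my data are placeholders, not my actual code:

    import torch.nn as nn

    class EmbeddingNet(nn.Module):
        """Stage 1: reduce one large input chunk (e.g. a patch) to a small embedding."""
        def __init__(self, in_channels=3, embed_dim=128):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(32, embed_dim),
            )

        def forward(self, x):
            return self.backbone(x)

    class PredictionHead(nn.Module):
        """Stage 2: aggregate all embeddings of one instance into a prediction."""
        def __init__(self, embed_dim=128, num_classes=2):
            super().__init__()
            self.classifier = nn.Linear(embed_dim, num_classes)

        def forward(self, embeddings):
            # Mean-pool over the patch dimension, then classify the instance.
            return self.classifier(embeddings.mean(dim=0, keepdim=True))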

There are surely many ways to implement this.
I was particularly surprised by a memory leak I encountered with the following approach:

        from collections import defaultdict

        self.network.eval()
        with torch.no_grad():
            instances = defaultdict(list)
            for batch in dataloader:
                img = batch["img"].to(self.device)
                # Compute the embedding without tracking gradients and store it on the CPU.
                batch["embedding"] = self.network(img).detach().cpu()
                for key in batch:
                    instances[key].append(batch[key])

I believe that tensor.to(self.device) is not an in-place operation and results in a copy. So when I later collect batch["img"] and batch["embedding"] for the backward pass, they do not live on the GPU.
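Concretely, the later pass I have in mind looks roughly like the sketch below; self.prediction_head, self.loss_fn and target are placeholders for my actual code:

        # Later, outside the no_grad block: move the cached CPU tensors back to
        # the GPU (again as copies) and backpropagate through the prediction head only.
        embeddings = torch.cat(instances["embedding"]).to(self.device)
        embeddings.requires_grad_()  # collect gradients w.r.t. the embeddings
        loss = self.loss_fn(self.prediction_head(embeddings), target)
        loss.backward()
        # embeddings.grad could then be used to re-run selected patches through
        # self.network with gradients enabled, as in gradient checkpointing.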

At which point am I blocking the GPU's memory from being freed?