GPU memory usage keeps increasing with every batch

I am running into a CUDA out-of-memory error while using PyTorch.
GPU memory usage increases with every batch until the run eventually fails with a CUDA out-of-memory error.
The following code snippet is my dataset implementation.

import os

import ndjson
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

# NeighborFinder (k-hop neighbor graph search) is defined elsewhere in my codebase


class MOMARetrievalEvalDataset(Dataset):
    def __init__(self, cfg):
        super().__init__()
        
        self.cfg = cfg
        
        # base annotation data (preprocessed)
        with open("dataset/moma_val.ndjson", "r") as f:
            self.anno = ndjson.load(f)
            
        # sub-activity caption embedding lookup
        self.caption_emb_lookup = torch.load("dataset/moma_caption_emb.pt")
        
        self._prepare_batches()
        
    def _prepare_batches(self):
        self.batches = []
        for anno in self.anno:
            batch = {
                "src_video_id": anno["video_id"], # video id e.g. '-49z-lj8eYQ'
                "src_activity_id": anno["activity_id"], # activity class id e.g. 2
                "src_activity_name": anno["activity_name"], # activity name e.g. "basketball game"
                "trg_video_ids": [x["video_id"] for x in self.anno if anno["video_id"] != x["video_id"]],
                "trg_activity_ids": [x["activity_id"] for x in self.anno if anno["activity_id"] != x["activity_id"]],
                "trg_activity_names": [x["activity_name"] for x in self.anno if anno["activity_name"] != x["activity_name"]],
            }
            
            proxy_similarities = []
            for trg_vid in batch["trg_video_ids"]:
                emb1 = torch.tensor(self.caption_emb_lookup[anno["video_id"]])
                emb1 = F.normalize(emb1, dim=-1)
                emb2 = torch.tensor(self.caption_emb_lookup[trg_vid])
                emb2 = F.normalize(emb2, dim=-1)
                proxy_similarities.append(torch.mm(emb1, emb2.t()).mean())   
                
            batch["proxy_similarities"] = torch.tensor(proxy_similarities)
            
            self.batches.append(batch)
        
    def load_graph(self, vid):
        g_df = pd.read_csv(
            os.path.join(self.cfg.DATASET.moma.path, "graphs", f"{vid}.csv")
        )
        
        src_l = g_df.src.values
        dst_l = g_df.trg.values
        e_idx_l = g_df.idx.values
        ts_l = g_df.ts.values
        
        n_feat = np.load( # num_nodes x d (node features)
            os.path.join(self.cfg.DATASET.moma.path, "graphs", f"{vid}_node.npy")
        )
        e_feat = np.load( # num_edges x d (edge features)
            os.path.join(self.cfg.DATASET.moma.path, "graphs", f"{vid}_edge.npy")
        )
        
        n_feat_th = torch.nn.Parameter(torch.from_numpy(n_feat.astype(np.float32)), requires_grad=False)
        e_feat_th = torch.nn.Parameter(torch.from_numpy(e_feat.astype(np.float32)), requires_grad=False)
        
        node_raw_embed = torch.nn.Embedding.from_pretrained(n_feat_th, padding_idx=0, freeze=True).to("cuda")
        edge_raw_embed = torch.nn.Embedding.from_pretrained(e_feat_th, padding_idx=0, freeze=True).to("cuda")
        
        # full adjacency list
        max_idx = max(src_l.max(), dst_l.max())
        full_adj_list = [[] for _ in range(max_idx + 1)]
        for src, dst, eidx, ts in zip(src_l, dst_l, e_idx_l, ts_l):
            full_adj_list[src].append((dst, eidx, ts))
            # full_adj_list[dst].append((src, eidx, ts)) # if undirected
            
        # for k-hop neighbor graph search
        ngh_finder = NeighborFinder(
            full_adj_list, bias=1e-5, use_cache=False, sample_method="binary" # TODO: args to config!
        )
        
        graph = {
            "src_l": src_l,
            "dst_l": dst_l,
            "e_idx_l": e_idx_l,
            "ts_l": ts_l,
            "node_raw_embed": node_raw_embed,
            "edge_raw_embed": edge_raw_embed,
            "ngh_finder": ngh_finder,
        }
        
        return graph

    def __len__(self):
        return len(self.batches)
        
    def __getitem__(self, idx):
        batch = self.batches[idx] # batch size is always 1 for eval
        
        src_graph = self.load_graph(batch["src_video_id"])
        batch["src_graph"] = src_graph
        
        trg_graphs = [
            self.load_graph(trg_vid) for trg_vid in batch["trg_video_ids"]
        ]
        batch["trg_graphs"] = trg_graphs
        
        return batch

and the dataloader is set up as follows.

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,  # always 1 for evaluation
    shuffle=False,
    num_workers=0,
    collate_fn=val_collate_fn,
)

The same problem also occurs when I just iterate over the dataloader:

for batch in val_dataloader:
    ...
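For reference, this is roughly how I watch the allocated memory while doing nothing but iterating (a minimal sketch; there is no model or forward pass involved):

for i, batch in enumerate(val_dataloader):
    # nothing else touches the GPU here, yet the allocated memory keeps growing
    allocated_mb = torch.cuda.memory_allocated() / 1024 ** 2
    print(f"iter {i}: {allocated_mb:.1f} MiB allocated")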

Please help me…

Based on your code snippet it seems you are assigning CUDA tensors to batch["src_graph"] and batch["trg_graphs"], and I would assume these are kept alive in the internal self.batches list (you could double check it by trying to access e.g. self.batches[idx-1] in a later iteration).
In this case the memory increase would be expected, since you are storing data on the GPU.
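For example, something along these lines should show it (just a sketch; val_dataset refers to your MOMARetrievalEvalDataset instance):

_ = val_dataset[0]
_ = val_dataset[1]

prev = val_dataset.batches[0]  # an entry that __getitem__ has already processed
print("src_graph" in prev)     # True -> the graph dict was written back into the cache
print(prev["src_graph"]["node_raw_embed"].weight.is_cuda)  # True -> it is still holding GPU memory

Since each cached entry then keeps its Embedding modules on the GPU, the allocated memory grows with every new index you access.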

Oh, I have found that self.batches[idx-1] still contains "src_graph" and "trg_graphs"!
I think I need to use something like a deep copy when getting self.batches[idx], instead of mutating the cached entry in place.
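Something like this is what I have in mind (a sketch; a shallow copy of the dict should already be enough here, since load_graph builds fresh objects on every call):

# inside MOMARetrievalEvalDataset
def __getitem__(self, idx):
    # copy the cached entry so the graphs are attached to the copy only,
    # not to the dict that stays alive in self.batches
    batch = dict(self.batches[idx])  # or copy.deepcopy(self.batches[idx])

    batch["src_graph"] = self.load_graph(batch["src_video_id"])
    batch["trg_graphs"] = [
        self.load_graph(trg_vid) for trg_vid in batch["trg_video_ids"]
    ]
    return batch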
I really appreciate your help.
Thank you :slight_smile: