How to handle RAM OOM in DDP?

I tried to implement DDP in my PyTorch code. The resources I used:

  1. 1 GPU, 2 processes, each one with 1-2 CPUs.
    I still get an OOM kill before moving the data to the GPU, so it appears that the CPU memory is full. The problem is that each process loads the data file into RAM. Is there any way to avoid that? I found something about shared memory, but no PyTorch solution so far.
import pandas as pd
import torch
from torch.utils.data import Dataset


class DrugProteinDataset_batchN(Dataset):
    def __init__(self, dataframe, preloaded_data_file, device="cpu", normalize=True):
        """
        PyTorch Dataset for DTI prediction.
        """
        self.device = device
        self.normalize = normalize
        self.preloaded_data = torch.load(preloaded_data_file, map_location="cpu")

        # Preprocess the IDs
        dataframe['hsa_id'] = dataframe['hsa_id'].apply(lambda x: 'hsa_' + x[3:])

        # Filtering: keep only rows whose graphs and embeddings are all present
        valid_rows = []
        for _, row in dataframe.iterrows():
            drug_id = row['drug_id']
            hsa_id = row['hsa_id']
            if (
                self.preloaded_data["drug_graphs"].get(drug_id) is not None and 
                self.preloaded_data["protein_graphs"].get(hsa_id) is not None and
                self.preloaded_data["distance_embeddings"].get(hsa_id) is not None and
                self.preloaded_data["angular_embeddings"].get(hsa_id) is not None
            ):
                valid_rows.append(row)

        self.dataframe = pd.DataFrame(valid_rows).reset_index(drop=True)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        drug_id = row['drug_id']
        hsa_id = row['hsa_id']
        label = torch.tensor(row['label'], dtype=torch.float, device=self.device)

        drug_graph = self._load_graph(self.preloaded_data["drug_graphs"].get(drug_id), self.normalize)
        protein_graph = self._load_graph(self.preloaded_data["protein_graphs"].get(hsa_id), self.normalize, is_protein=True)

        distance_embeddings = self.preloaded_data["distance_embeddings"].get(hsa_id)
        angular_embeddings = self.preloaded_data["angular_embeddings"].get(hsa_id)

        return drug_graph, protein_graph, label, distance_embeddings, angular_embeddings

    def _load_graph(self, graph, normalize=True, is_protein=False):
        if graph is None:
            return None
        graph = graph.clone()
        if is_protein:
            # Drop the first three node-feature columns for protein graphs
            graph.x = graph.x[:, 3:]
            if normalize:
                graph.x = normalize_protein_features(graph.x)
        elif normalize:
            graph.x = z_score_normalize_drug_features(graph.x)
        return graph

the preloaded file structure:

preloaded_data = {
    "drug_graphs": drug_graphs,
    "protein_graphs": protein_graphs,
    "distance_embeddings": distance_embeddings,
    "angular_embeddings": angular_embeddings
}

The graphs are PyG Data objects and the embeddings are dictionaries.

Yes, lazily loading the dataset as explained in your previous post could reduce the memory usage and thus avoid the OOM if applicable.

Yes, I already applied that, but the problem still appears with a dataset of 17,000 samples (it works fine for 6,000 and 900 samples). So I'm asking whether there is clear documentation on putting the file in shared memory so that all processes can access it during data loading. I'm trying this because I want to avoid a multi-node solution.
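Ideally, what I'm imagining is something along these lines (just a sketch, assuming a recent PyTorch where torch.load accepts mmap=True so the tensor storages are memory-mapped from the file and shared through the OS page cache instead of being fully copied into each process; the class name is made up and I haven't tested it on my data):

import torch
from torch.utils.data import Dataset


class DrugProteinDatasetMmap(Dataset):
    """Sketch: same file, but memory-mapped instead of fully loaded per process."""

    def __init__(self, dataframe, preloaded_data_file, normalize=True):
        self.normalize = normalize
        # mmap=True keeps the tensor storages backed by the file on disk and maps
        # them into memory on access; weights_only=False is needed because the
        # file also contains PyG Data objects, not just tensors.
        self.preloaded_data = torch.load(
            preloaded_data_file, map_location="cpu", mmap=True, weights_only=False
        )
        self.dataframe = dataframe.reset_index(drop=True)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Filtering / normalization from the original class omitted for brevity.
        row = self.dataframe.iloc[idx]
        drug_graph = self.preloaded_data["drug_graphs"].get(row["drug_id"])
        protein_graph = self.preloaded_data["protein_graphs"].get(row["hsa_id"])
        label = torch.tensor(row["label"], dtype=torch.float)
        return drug_graph, protein_graph, label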

It seems you are still preloading data:

class DrugProteinDataset_batchN(Dataset):
    def __init__(self, dataframe, preloaded_data_file, device="cpu", normalize=True):
        ...
        self.preloaded_data = torch.load(preloaded_data_file, map_location="cpu")
        ...

so you might want to check whether moving this into __getitem__ could avoid the issue.
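Something along these lines, for example (a rough sketch reusing your names; untested):

import torch
from torch.utils.data import Dataset


class DrugProteinDataset_lazy(Dataset):
    def __init__(self, dataframe, preloaded_data_file, device="cpu", normalize=True):
        self.device = device
        self.normalize = normalize
        self.preloaded_data_file = preloaded_data_file
        self.preloaded_data = None  # nothing is loaded here anymore
        self.dataframe = dataframe.reset_index(drop=True)

    def _data(self):
        # Load the file on the first sample request, i.e. inside the process
        # (or DataLoader worker) that actually consumes the data.
        if self.preloaded_data is None:
            self.preloaded_data = torch.load(self.preloaded_data_file, map_location="cpu")
        return self.preloaded_data

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        data = self._data()
        row = self.dataframe.iloc[idx]
        drug_graph = data["drug_graphs"].get(row["drug_id"])
        protein_graph = data["protein_graphs"].get(row["hsa_id"])
        label = torch.tensor(row["label"], dtype=torch.float)
        return drug_graph, protein_graph, label

Note that each rank (and each DataLoader worker) will still end up with its own copy of the dict once it starts pulling samples, so for truly shared memory you would still need something like a memory-mapped load or per-sample files on disk.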