DataLoader process gets killed


It looks like I have the same problem that many other people have encountered, but I can’t fix it.

When I try to load an image from `train_loader`, the process crashes and prints: `Killed`

Here is the dataset code:

class ScansDataset(Dataset):
    """Dataset of aligned PET + MRI scans together with their labels.

    ``preprocessing_parameters["path_for_preprocessed_data"]`` must contain
    three sub-directories, ``PET``, ``MRI`` and ``label``, each holding one
    ``<id>.npy`` file per sample.  Items are returned as
    ``(name, pet, mri, label)``.
    """

    def __init__(self, preprocessing_parameters):
        root = preprocessing_parameters["path_for_preprocessed_data"]
        # Directory of .npy files for the PET scans
        self.dir_data_pet = os.path.join(root, "PET")
        # Directory of .npy files for the MRI scans
        self.dir_data_mri = os.path.join(root, "MRI")
        # Directory of .npy files for the labels
        self.dir_data_labels = os.path.join(root, "label")
        # __len__/__getitem__ need self.list_IDs, so build it here instead of
        # relying on the caller to remember generate_list_ID() (without it,
        # len(dataset) raised AttributeError).  This also validates that the
        # PET and MRI directories hold the same scans.
        self.generate_list_ID()

    def __len__(self):
        """Return the number of samples (IDs with PET, MRI and label files)."""
        return len(self.list_IDs)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(name, pet, mri, label)``.

        ``pet`` and ``mri`` get a leading channel axis (an array stored as
        (X, Y, Z) becomes a (1, X, Y, Z) tensor); ``label`` keeps the axis
        order it was stored with.  dtypes are kept as stored in the .npy files
        (no uint8 -> float rescaling is applied).
        """
        import torch  # local import keeps this method self-contained

        name = self.list_IDs[item]
        pet = np.load(os.path.join(self.dir_data_pet, name + ".npy"))
        mri = np.load(os.path.join(self.dir_data_mri, name + ".npy"))
        label = np.load(os.path.join(self.dir_data_labels, name + ".npy"))
        # torchvision's ToTensor only handles 2-D/3-D (H, W[, C]) images and
        # rejects these volumes once a 4th axis is added, so convert directly.
        # The old "append channel last, then let ToTensor move it first" dance
        # (and the label transpose that ToTensor would have undone) reduces to
        # prepending the channel axis for the scans and leaving label as is.
        pet = torch.from_numpy(np.ascontiguousarray(pet[np.newaxis]))
        mri = torch.from_numpy(np.ascontiguousarray(mri[np.newaxis]))
        label = torch.from_numpy(np.ascontiguousarray(label))
        return name, pet, mri, label

    def generate_list_ID(self):
        """(Re)build ``self.list_IDs`` from the files on disk.

        An ID is a filename stem ``<id>`` for which ``<id>.npy`` exists in the
        PET, MRI *and* label directories; the list is sorted so the order is
        reproducible.  Hidden files and non-``.npy`` files are ignored.

        Raises:
            ValueError: if the PET and MRI directories hold different scans.
        """
        ids_pet = self._npy_ids(self.dir_data_pet)
        ids_mri = self._npy_ids(self.dir_data_mri)
        ids_labels = self._npy_ids(self.dir_data_labels)
        # Compare as sets: os.listdir returns entries in arbitrary order.
        if ids_pet != ids_mri:
            mismatched = sorted(ids_pet ^ ids_mri)
            raise ValueError(
                f"PET and MRI directories hold different scans, e.g. {mismatched[:5]}"
            )
        if ids_pet != ids_labels:
            print(
                "       ----WARNING---- "
                f"{len(ids_pet - ids_labels)} scan(s) without a label and "
                f"{len(ids_labels - ids_pet)} label(s) without scans are skipped"
            )
        # Only samples with scans and a label are usable; any other ID would
        # fail in __getitem__ when its missing .npy file is loaded.
        self.list_IDs = sorted(ids_pet & ids_labels)

    @staticmethod
    def _npy_ids(directory):
        """Return the stems of the visible ``.npy`` files in *directory*."""
        return {
            stem
            for stem, ext in map(os.path.splitext, os.listdir(directory))
            if ext == ".npy" and not stem.startswith(".")
        }

And here is the training code (abridged):

dataset = ScansDataset(parameters["preprocessing"])
# Populate dataset.list_IDs before len(dataset) is used; calling this again is
# harmless because it only re-lists the data directories.
dataset.generate_list_ID()

train_size = int(0.8 * len(dataset))
validation_size = len(dataset) - train_size

# Random 80/20 split into training and validation subsets.
train_dataset, validation_dataset = torch.utils.data.random_split(
    dataset, [train_size, validation_size]
)
# NOTE(review): a bare "Killed" with no Python traceback usually means the
# kernel's OOM killer ended the process.  Under `srun` the job is capped by its
# Slurm memory allocation (`--mem` / `--mem-per-gpu`), which can be much lower
# than the node memory htop reports; confirm that limit before tuning loaders.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, drop_last=False)
val_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False, drop_last=False)

loss_fn = SoftDice_and_BCELoss()

## As defined in nnUNet
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=3e-5,
#                                          momentum=0.99, nesterov=True)
optimizer = torch.optim.Adam(model.parameters(), lr=parameters["training"]["learning_rate"])

training_loss = {}  # Dict mapping epoch -> dict(name, loss)
validation_loss = {}
epochs = parameters["training"]["epochs"]
print("Starting training")
for epoch in range(epochs):
    start_time = time.time()
    # NOTE(review): this rebinds validation_loss, discarding the epoch -> dict
    # mapping created above on every epoch; if per-epoch history is intended,
    # give the per-epoch dict its own name.
    train_loss, validation_loss = {}, {}
    # Train
    for cnt, (name, pet, mri, label) in enumerate(train_loader):
        # Mock inputs previously used for debugging, shape (batch, 1, 64, 64, 128):
        # mock_pet = torch.randn(2, 1, 64, 64, 128)
        # mock_mri = torch.randn(2, 1, 64, 64, 128)
        output = model(pet, mri)

What I have tried so far:
- Running the code with `num_workers=0`
- Monitoring the memory with htop (it doesn't look like I run out of memory)
- The failing line is `for cnt, (name, pet, mri, label) in enumerate(train_loader):`

I'm running the code on a remote server (maybe that is important?).
I start a session with `srun -p gpu --gpus 1 --pty bash`

I'm running out of ideas, so any help would be much appreciated!
Thank you so much in advance.

Is the process being killed immediately or after a few iterations? You might want to “instrument” your program and see whether removing some data-loading steps or inputs makes the crash go away, to find where the actual problem lies.