PyTorch data generator is extremely slow

Hi all,

I am new to PyTorch and want to convert one of my Keras projects, in which I use patch-wise training for segmentation.
All preprocessing is done before training and the preprocessed volumes are saved to disk as pickled files.

The data generator then reads a pickled file (a whole volume) and randomly crops a patch out of it.

Everything runs fine, but compared to Keras the PyTorch generator is very slow:
data loading takes about 0.5 s in Keras, while PyTorch needs nearly 3 s.

Now I am wondering what I am doing wrong.

Here is my generator:

#==============================================================================#
#  
#  This program is free software: you can redistribute it and/or modify        #
#  it under the terms of the GNU General Public License as published by        #
#  the Free Software Foundation, either version 3 of the License, or           #
#  (at your option) any later version.                                         #
#                                                                              #
#  This program is distributed in the hope that it will be useful,             #
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
#  GNU General Public License for more details.                                #
#                                                                              #
#  You should have received a copy of the GNU General Public License           #
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#==============================================================================#
#-----------------------------------------------------#
#                   Library imports                   #
#-----------------------------------------------------#
#External libraries
import math
import numpy as np
from torch.utils.data import Dataset    
import torch
#-----------------------------------------------------#
#                 Pytorch Data Generator                #
#-----------------------------------------------------#
# Data Generator for generating batches (WITH-/OUT segmentation)
## Returns a batch containing one or multiple images for training/prediction
class SegmentationDataset(Dataset):
    """Map-style PyTorch dataset yielding one (image, segmentation) pair per sample.

    Every ``__getitem__`` call delegates to ``preprocessor.run`` for the sample
    at that index, so all loading / cropping work is paid at access time.

    NOTE(review): the surrounding post says each call unpickles a whole volume
    before cropping a patch. That per-item disk I/O is the likely cost driver.
    Running the ``DataLoader`` with ``num_workers > 0`` parallelises it, and
    caching or pre-extracting patches avoids it. Confirm with a profiler.
    """

    # Class Initialization
    def __init__(self, sample_list, preprocessor, training=False,
                 validation=False, shuffle=False, iterations=None):
        """Store the sample list and preprocessing configuration.

        Args:
            sample_list: list or numpy array of sample identifiers. A list is
                shallow-copied; an array is converted with ``tolist()``.
            preprocessor: object exposing ``prepare_subfunctions``,
                ``run_subfunctions(samples, training)`` and
                ``run(sample, training, validation)``.
            training: flag forwarded to the preprocessor.
            validation: flag forwarded to ``preprocessor.run``.
            shuffle: stored only; not used by this class (ordering is the
                ``DataLoader``'s job via its own ``shuffle`` argument).
            iterations: stored only; not used by this class.

        Raises:
            ValueError: if ``sample_list`` is neither a list nor a numpy array.
        """
        # Normalize the sample list to a private Python list
        if isinstance(sample_list, list):
            self.sample_list = sample_list.copy()
        elif isinstance(sample_list, np.ndarray):
            self.sample_list = sample_list.tolist()
        else:
            raise ValueError("Sample list have to be a list or numpy array!")
        # Create a working environment from the handed over variables
        self.preprocessor = preprocessor
        self.training = training
        self.validation = validation
        self.shuffle = shuffle
        self.iterations = iterations
        # Kept for backward compatibility; nothing in this class reads it
        self.batch_queue = []
        # If samples with subroutines should be preprocessed -> do it now.
        # Use the normalized list so the preprocessor always sees a list,
        # whatever type the caller handed in.
        if preprocessor.prepare_subfunctions:
            preprocessor.run_subfunctions(self.sample_list, training)

    # Return the sample pair for the associated index
    def __getitem__(self, idx):
        """Return ``(image, segmentation)`` float32 tensors for sample ``idx``.

        The same structure is returned in training and non-training mode;
        ``self.training`` only affects what the preprocessor is asked to do.
        """
        batch = self.generate_batch(idx)
        # NOTE(review): assumes preprocessor.run returns a sequence whose first
        # entry is an (image, segmentation) pair of channels-last arrays;
        # verify against the preprocessor implementation.
        img = self._to_channels_first_tensor(batch[0][0])
        seg = self._to_channels_first_tensor(batch[0][1])
        return img, seg

    # Return the number of samples for one epoch
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.sample_list)

    # Generate a sample during runtime
    def generate_batch(self, idx):
        """Run the preprocessor on the sample at ``idx`` and return its output."""
        sample = self.preprocessor.run(self.sample_list[idx], self.training,
                                       self.validation)
        return sample

    @staticmethod
    def _to_channels_first_tensor(array):
        """Move the last axis of ``array`` to the front; return a float32 tensor."""
        # moveaxis returns a strided view; from_numpy shares it and .float()
        # copies only when the dtype is not already float32.
        return torch.from_numpy(np.moveaxis(array, -1, 0)).float()

       

Thanks for any advice,

cheers,

Michael