Memory leak in loading MRI data for neural network

I am loading data for a variational autoencoder, and the memory used continues to rise, eventually killing the process. The memory always increases at `next(iter(generator))` but never goes down as different data is loaded.

I am loading MRI data five scans at a time and then pulling batches of 128 voxels from those five scans. Each voxel is pulled in together with its 6 neighbouring voxels, giving a patch of voxels for use in the network.

To try and find this memory leak I have the rest of the model commented out so only the data is being loaded and not used for anything.

I have attempted a number of memory profiling methods but cannot find the cause of this increase in memory between loading the data. Why is the memory not being freed when a different dataset is loaded?

I’ve added `gc.collect()` each epoch, which I believe helped with a reduced network, but with everything added there is still an issue. I have used some memory profiling but cannot seem to track down the actual issue. Previously I had tqdm in the loop to track progress; removing it did seem to prevent the leak, but there must be something else. Any help would be appreciated.

# Training-loop driver as posted: per epoch, walks the mini scan sets, builds one Dataset and one batch generator per scan, and accumulates a KL + reconstruction loss.
# NOTE(review): this listing is incomplete. Indentation was lost in the paste (the body sits at column 0) and several statements are truncated, so it does not run as shown.
# NOTE(review): adversarial_model, train_scan_set, test_scan_set, ml_def and verbose are not referenced in the visible code.
def train_model(vae_model, adversarial_model, train_scan_set, test_scan_set, ml_def, verbose=False):
# Keyword arguments later unpacked with **params when the per-scan generator is built; batch_size is not defined in the visible code.
params = {'batch_size': batch_size, 

    'shuffle': True, 

    'num_workers': 0} 

# Summed (not averaged) squared error, used for the coefficient reconstruction term below.
criterion = torch.nn.MSELoss(reduction='sum')  
# Only vae_model's parameters are handed to the optimiser; learning_rate is not defined in the visible code.
optimizer = torch.optim.Adam(vae_model.parameters(), lr=learning_rate)  

# These two lists are re-created at the top of every `itera` iteration below, so the values set here are overwritten.
saved_batch_data = [] 

saved_batch_labels = []  

for epoch in range(n_epochs): 

    epoch_start_date = datetime.utcnow() 

    epoch_loss = 0.0 

    mini_scan_set_counter = 0 
    # mini_scan_set_lists is not defined in the visible code; per the accompanying text each entry is presumably a group of five scans — confirm.
    for mini_scan_set_id in range(len(mini_scan_set_lists)): 

         mini_scan_set = mini_scan_set_lists[mini_scan_set_id] 


        mini_scan_set_counter += 5 

        batch_counter = 0 
        # `True or ...` is always true, so `datasets` is reset to an empty list for every mini scan set.
        if True or mini_scan_set_id == 0: 

            datasets = [] 
        # NOTE(review): this loop has no body in the listing (statement(s) missing), and `datasets` is empty at this point anyway.
        for dataset in datasets: 



        generators = [] 

        for mini_scan_id, scan in enumerate(mini_scan_set): 

            scan_dir = scan.get_storage_address() 
            # NOTE(review): `dataset` is never appended to `datasets` in the visible code, so zip(generators, datasets) below would yield nothing as written.
            dataset = Dataset(scan) 

            assert len(dataset) > 0, ["len is 0", scan.scan_name] 

            # NOTE(review): the right-hand side is truncated; presumably a torch DataLoader over `dataset` built with **params — confirm.
            training_dataset_generator =, **params) 

            generators.append(training_dataset_generator) # list of 5 scans' dataset 

        scan_loss = 0.0 


        i = 0 

        # Iterate as many times as the longest generator has batches.
        max_generator_len = np.max(list(map(len, generators))) 

        for itera in range(max_generator_len): 

            saved_batch_data = [] 

            saved_batch_labels = [] 

            saved_centre_voxel_intensities = [] 

            n_batches = 0  

            for generator, dataset in zip(generators, datasets): 

                batch_i = itera 
                # Repeated subtraction wraps batch_i into [0, len(generator)); batch_i is not used afterwards in the visible code.
                while batch_i >= len(generator): 

                    batch_i -= len(generator) 
                # iter(generator) builds a brand-new iterator on every call and only its first batch is consumed; the post reports memory rising at this call.
                # NOTE(review): if `generator` is a shuffling DataLoader this draws a random first batch each time rather than stepping through the data; creating one iterator per generator outside the `itera` loop and calling next() on it would avoid the per-call iterator — confirm intent.
                batch, labels, centre_voxel_intensities = next(iter(generator))  

                # NOTE(review): nothing is appended to the saved_* lists in the visible code, yet they are indexed by batch_id below; the appends were presumably dropped from the listing.


                n_batches += 1 

            mini_epoch_loss = 0 

            for batch_id in range(n_batches): 

                saved_batch = saved_batch_data[batch_id] 

                saved_labels = saved_batch_labels[batch_id] 

                centre_voxel_intensities = saved_centre_voxel_intensities[batch_id] 

                saved_dataset = datasets[batch_id] 

                batch_counter += 1 
                # NOTE(review): right-hand side truncated in the listing; the model forward pass that would define z_mu, z_var and reconstructed_coefficients is also not shown.
                local_batch, local_site_id, local_centre_voxel =,, ### 

                # Matches the closed-form KL divergence to a unit Gaussian if z_var holds the log-variance — TODO confirm.
                kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)  
                # Reconstruction error against the first coefficient_dimension_output columns of the batch.
                coefficient_reconstruction_loss = criterion(reconstructed_coefficients, local_batch[:, :coefficient_dimension_output])  
                # Fixed weights: 1e-5 on the KL term, 1e-32 on the reconstruction term.
                loss = 1e-5*kl_loss + 1e-32*coefficient_reconstruction_loss  
                # NOTE(review): if `loss` is a tensor that requires grad, adding it to these running totals keeps each iteration's autograd graph referenced; accumulating loss.item() instead is the usual remedy — confirm.
                mini_epoch_loss += loss  

                scan_loss += loss  

                epoch_loss += loss  




                # These only clear the .grad attributes; no backward/optimizer step is visible in the listing.
                batch.grad = None 

                loss.grad = None 
    # Garbage collection pass once per epoch; the returned count `n` is not used in the visible code.
    n = gc.collect() 

    epoch_end_date = datetime.utcnow() 

# NOTE(review): the class header is truncated in this listing (the base class and closing parenthesis are missing) and the original indentation was lost, so the definitions below appear at inconsistent columns.
class Dataset(
     # Finaliser: replaces the object held in `dti_nifti` with None so this instance no longer references it.
     def __del__(self): 

    self.dti_nifti = None 

def __len__(self): 
    """Return the item count stored in ``self.len``."""
    return self.len 

# Builds one sample for item `i`: fits spherical-harmonic coefficients for the indexed voxel and its six face neighbours, and returns them with a site one-hot vector and the centre voxel's raw intensities.
# NOTE(review): the listing is incomplete — the patch_indices list is never closed and nothing is appended to `output`, so it does not run as shown.
def __getitem__(self, i): 
    # NOTE(review): if dti_nifti is a nibabel image, get_fdata() caches the full array on the image by default, so each Dataset keeps its whole volume in memory once an item has been served; reading only the needed voxels via `dataobj`, or calling uncache(), may be relevant to the reported memory growth — confirm.
    data = self.dti_nifti.get_fdata() 
    # Spatial coordinates of the centre voxel for this item.
    voxel_index = self.nii_indices[i] 

    x = voxel_index[0] 

    y = voxel_index[1] 

    z = voxel_index[2] 
    # Centre voxel first, then the +/-1 neighbour along each of the three axes.
    patch_indices = [ 

        [x, y, z], 

        [x+1, y, z], 

        [x-1, y, z], 

        [x, y+1, z], 

        [x, y-1, z], 

        [x, y, z+1], 

        [x, y, z-1],  
    # NOTE(review): the closing bracket of patch_indices is missing from the listing.
    # NOTE(review): no bounds check is visible; a coordinate of 0 gives index -1, which would wrap to the far edge if `data` is a numpy array — confirm nii_indices excludes border voxels.
    output = [] 
    # Counts patch voxels processed so far; 0 identifies the centre voxel (first entry of patch_indices).
    centre_voxel_info_fill = 0 
    # The loop variables reuse the names x, y, z, overwriting the centre coordinates bound above.
    for x, y, z in patch_indices: 
        # Mean of this voxel's values at the volumes listed in bzero_volume_ids; not used again in the visible code.
        mean_bzero_val = np.mean(data[x,y,z][self.bzero_volume_ids]) 
        # Single-element list: this `for` runs once and simply binds bvecs and intensity_vals.
        for bvecs, intensity_vals in [(self.bvecs_array[self.non_zero_bvals_index], data[x,y,z,self.non_zero_bvals_index])]: 
            # Component lists of the selected vectors and of their negations.
            vecs_x_orig = [bvec[0] for bvec in bvecs] 

            vecs_y_orig = [bvec[1] for bvec in bvecs] 

            vecs_z_orig = [bvec[2] for bvec in bvecs] 

            vecs_x_neg = [-bvec[0] for bvec in bvecs] 

            vecs_y_neg = [-bvec[1] for bvec in bvecs] 

            vecs_z_neg = [-bvec[2] for bvec in bvecs] 
            # Each direction is followed by its mirror image, doubling the number of sample directions.
            vecs_x = vecs_x_orig + vecs_x_neg 

            vecs_y = vecs_y_orig + vecs_y_neg 

            vecs_z = vecs_z_orig + vecs_z_neg 
            # NOTE(review): bvecs depends only on self, so these angles are the same for every patch voxel and every item; they could be computed once — confirm before hoisting.
            cart_space = list(map(sh_convert.cart2sph, (vecs_x), (vecs_y), (vecs_z))) 
            # Element 0 of each result is used as latitude and element 1 as longitude (per the variable names; cart2sph is not shown).
            angles_lat = [ang[0] for ang in cart_space] 

            angles_lon = [ang[1] for ang in cart_space] 
            # Keep the centre voxel's intensities, captured before they are duplicated below.
            if centre_voxel_info_fill == 0: 

                centre_voxel_int_values = intensity_vals 
            # Duplicate the intensities so they line up with the original + negated directions.
            intensity_vals = np.concatenate([intensity_vals, intensity_vals]) 
            # Spherical-harmonic expansion via pyshtools' SHExpandLSQ; sh_error is not used in the visible code.
            sh_coefficients, sh_error = pysh.shtools.SHExpandLSQ(intensity_vals, angles_lat, angles_lon, spherical_harmonic_degree) 

            sh_coefficients = sh_convert.reshape_sh_coeffs(sh_coefficients, degree=spherical_harmonic_degree) 

            sh_coefficients = sh_coefficients / coefficient_normalisation_factor 

        # NOTE(review): sh_coefficients is never appended to `output` in the visible code; the append was presumably dropped from the listing.

        centre_voxel_info_fill += 1 

    # As listed, `output` is still empty here, so values_array would be an empty array.
    values_array = np.array(output) 
    # create_one_hot_matrix returns a torch tensor; it is converted with np.float32 on return.
    one_hot = self.create_one_hot_matrix()  

    return np.float32(values_array), np.float32(one_hot), np.float32(centre_voxel_int_values) 

def create_one_hot_matrix(self):
    """Return a float32 torch vector that one-hot encodes ``self.site_id``.

    The vector has one slot per entry of the module-level
    ``scan_set_site_list``; the slot at the position of ``self.site_id``
    is 1 and every other slot is 0.

    Raises:
        ValueError: if ``self.site_id`` is not in ``scan_set_site_list``
            (propagated from ``list.index``).
    """
    n_sites = len(scan_set_site_list)
    encoding = np.zeros((n_sites))
    encoding[scan_set_site_list.index(self.site_id)] = 1
    # np.zeros defaults to float64, so cast after wrapping to get float32.
    return torch.from_numpy(encoding).float()