Memory leak when loading MRI data for a neural network

I am loading data for a variational autoencoder, and the memory used keeps rising until the process is eventually killed. The memory always increases at the next(iter(generator)) call but never goes down as different data is loaded.

I am loading MRI data five scans at a time and then pulling batches of 128 voxels from those five scans. Each voxel is pulled in together with its six neighbouring voxels, giving a seven-voxel patch for use in the network (sketched below).
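To make the patch layout concrete, here is a minimal standalone sketch (patch_coords is a hypothetical helper, only for illustration); the offsets match the patch_indices built in __getitem__ further down:

import numpy as np

# centre voxel plus its six face neighbours along x, y and z
offsets = np.array([
    [0, 0, 0],
    [1, 0, 0], [-1, 0, 0],
    [0, 1, 0], [0, -1, 0],
    [0, 0, 1], [0, 0, -1],
])

def patch_coords(x, y, z):
    # coordinates of the seven voxels that make up one patch
    return np.array([x, y, z]) + offsets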

To try to find this memory leak I have commented out the rest of the model, so the data is only loaded and not used for anything.

I have attempted a number of memory-profiling methods but cannot find the cause of the increase in memory between loads. Why is the memory not freed when a different dataset is loaded?

I've added gc.collect() each epoch, which I believe helped with a reduced network, but with everything added back there is still an issue. I have done some memory profiling but cannot track down the actual cause. I previously had tqdm in the loop to track progress, and removing it did seem to reduce the leak, but there must be something else. Any help would be appreciated.
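For reference, this is roughly how I am measuring the memory growth; a minimal sketch assuming psutil, which is not part of the training code below:

import os

import psutil  # assumption: any resident-memory reporter would do here

_process = psutil.Process(os.getpid())

def log_rss(tag):
    # print the resident set size in MiB so successive points can be compared
    print(f"{tag}: {_process.memory_info().rss / (1024 ** 2):.1f} MiB")

Calling this before and after the next(iter(generator)) line is how I see that the number only ever rises.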

import gc
from datetime import datetime

import numpy as np
import pyshtools as pysh  # assumed import behind the pysh.shtools call below
import torch

import sh_convert  # project-local spherical-harmonic helpers


# batch_size, learning_rate, n_epochs, mini_scan_set_lists and the other bare
# names used below are module-level globals
def train_model(vae_model, adversarial_model, train_scan_set, test_scan_set, ml_def, verbose=False):
    params = {'batch_size': batch_size,
              'shuffle': True,
              'num_workers': 0}

    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(vae_model.parameters(), lr=learning_rate)

    saved_batch_data = []
    saved_batch_labels = []
    datasets = []  # reused across scan sets; see the freeing attempt below

    for epoch in range(n_epochs):
        epoch_start_date = datetime.utcnow()
        epoch_loss = 0.0
        mini_scan_set_counter = 0

        for mini_scan_set_id in range(len(mini_scan_set_lists)):
            mini_scan_set = mini_scan_set_lists[mini_scan_set_id]
            mini_scan_set_counter += 5
            batch_counter = 0

            # try to release the previous five scans before loading the next set
            # (originally guarded by mini_scan_set_id == 0; the guard was
            # disabled with "if True or ..." while chasing the leak)
            for dataset in datasets:
                del dataset
            datasets = []
            generators = []

            for mini_scan_id, scan in enumerate(mini_scan_set):
                scan_dir = scan.get_storage_address()
                dataset = Dataset(scan)
                assert len(dataset) > 0, ["len is 0", scan.scan_name]
                datasets.append(dataset)

                training_dataset_generator = torch.utils.data.DataLoader(dataset, **params)
                generators.append(training_dataset_generator)  # one DataLoader per scan in the set of five

            scan_loss = 0.0
            i = 0

            max_generator_len = np.max(list(map(len, generators)))

            for itera in range(max_generator_len):
                saved_batch_data = []
                saved_batch_labels = []
                saved_centre_voxel_intensities = []
                n_batches = 0

                for generator, dataset in zip(generators, datasets):
                    # wrap around the shorter loaders
                    batch_i = itera
                    while batch_i >= len(generator):
                        batch_i -= len(generator)

                    # memory jumps on this call and is never released
                    batch, labels, centre_voxel_intensities = next(iter(generator))

                    saved_batch_data.append(batch)
                    saved_batch_labels.append(labels)
                    saved_centre_voxel_intensities.append(centre_voxel_intensities)
                    n_batches += 1

                mini_epoch_loss = 0
                for batch_id in range(n_batches):
                    saved_batch = saved_batch_data[batch_id]
                    saved_labels = saved_batch_labels[batch_id]
                    centre_voxel_intensities = saved_centre_voxel_intensities[batch_id]
                    saved_dataset = datasets[batch_id]
                    batch_counter += 1

                    local_batch = saved_batch.to(ml_def.device)
                    local_site_id = saved_labels.to(ml_def.device)
                    local_centre_voxel = centre_voxel_intensities.to(ml_def.device)

                    # the forward pass producing z_mu, z_var and
                    # reconstructed_coefficients is commented out while
                    # debugging, so only the data loading runs
                    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
                    coefficient_reconstruction_loss = criterion(reconstructed_coefficients, local_batch[:, :coefficient_dimension_output])

                    loss = 1e-5 * kl_loss + 1e-32 * coefficient_reconstruction_loss

                    mini_epoch_loss += loss
                    scan_loss += loss
                    epoch_loss += loss

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # more attempts to release memory
                    batch.grad = None
                    loss.grad = None

        n = gc.collect()
        epoch_end_date = datetime.utcnow()

class Dataset(torch.utils.data.Dataset):

    def __del__(self):
        # drop the reference to the loaded NIfTI image when the dataset is deleted
        self.dti_nifti = None

    def __len__(self):
        return self.len

    def __getitem__(self, i):
        data = self.dti_nifti.get_fdata()

        voxel_index = self.nii_indices[i]
        x = voxel_index[0]
        y = voxel_index[1]
        z = voxel_index[2]

        # centre voxel plus its six face neighbours
        patch_indices = [
            [x, y, z],
            [x + 1, y, z],
            [x - 1, y, z],
            [x, y + 1, z],
            [x, y - 1, z],
            [x, y, z + 1],
            [x, y, z - 1],
        ]

        output = []
        centre_voxel_info_fill = 0

        for x, y, z in patch_indices:
            mean_bzero_val = np.mean(data[x, y, z][self.bzero_volume_ids])

            bvecs = self.bvecs_array[self.non_zero_bvals_index]
            intensity_vals = data[x, y, z, self.non_zero_bvals_index]

            # mirror each gradient direction so the directions cover the full sphere
            vecs_x_orig = [bvec[0] for bvec in bvecs]
            vecs_y_orig = [bvec[1] for bvec in bvecs]
            vecs_z_orig = [bvec[2] for bvec in bvecs]
            vecs_x_neg = [-bvec[0] for bvec in bvecs]
            vecs_y_neg = [-bvec[1] for bvec in bvecs]
            vecs_z_neg = [-bvec[2] for bvec in bvecs]
            vecs_x = vecs_x_orig + vecs_x_neg
            vecs_y = vecs_y_orig + vecs_y_neg
            vecs_z = vecs_z_orig + vecs_z_neg

            cart_space = list(map(sh_convert.cart2sph, vecs_x, vecs_y, vecs_z))
            angles_lat = [ang[0] for ang in cart_space]
            angles_lon = [ang[1] for ang in cart_space]

            # the first entry in patch_indices is the centre voxel
            if centre_voxel_info_fill == 0:
                centre_voxel_int_values = intensity_vals

            # duplicate the intensities to match the mirrored directions
            intensity_vals = np.concatenate([intensity_vals, intensity_vals])

            sh_coefficients, sh_error = pysh.shtools.SHExpandLSQ(intensity_vals, angles_lat, angles_lon, spherical_harmonic_degree)

            sh_coefficients = sh_convert.reshape_sh_coeffs(sh_coefficients, degree=spherical_harmonic_degree)
            sh_coefficients = sh_coefficients / coefficient_normalisation_factor

            output.append(mean_bzero_val)
            output.extend(sh_coefficients)

            centre_voxel_info_fill += 1

        values_array = np.array(output)

        one_hot = self.create_one_hot_matrix()
        return np.float32(values_array), np.float32(one_hot), np.float32(centre_voxel_int_values)

    def create_one_hot_matrix(self):
        # one-hot encode the scan's acquisition site against the global site list
        one_hot = np.zeros(len(scan_set_site_list))
        site_index = scan_set_site_list.index(self.site_id)
        one_hot[site_index] = 1
        one_hot_torch = torch.from_numpy(one_hot)
        one_hot_torch = one_hot_torch.float()
        return one_hot_torch