Big memory leak

Hi guys. I have a big memory leak within this loop:

It’s “within” PyTorch (well I’m 99% sure) because pympler.tracker.print_diff() is showing nothing from loop to loop, but the memory usage is going up (and by a lot - 30MB or so per itera loop)

Can anyone help me find out what the issue is?

for itera in range(max_generator_len):
                tracker.print_diff()

                saved_batch_data = []
                saved_batch_labels = []
                saved_centre_voxel_intensities = []
                n_batches = 0

                for generator, dataset in zip(generators, datasets):
                    batch_i = itera
                    while batch_i >= len(generator):
                        batch_i -= len(generator)

                    batch, labels, centre_voxel_intensities = next(iter(generator))

                    saved_batch_data.append(batch)
                    saved_batch_labels.append(labels)
                    saved_centre_voxel_intensities.append(centre_voxel_intensities)
                    n_batches += 1

                #

                for batch_id in range(n_batches):
                    batch = saved_batch_data[batch_id]
                    labels = saved_batch_labels[batch_id]
                    centre_voxel_intensities = saved_centre_voxel_intensities[batch_id]
                    dataset = datasets[batch_id]

                    batch_counter += 1

                    # transfer to gpu
                    local_batch, local_site_id, local_centre_voxel = batch.to(ml_def.device), labels.to(ml_def.device), centre_voxel_intensities.to(ml_def.device)

                    # VAE
                    reconstructed_coefficients, z_mu, z_var = vae_model(local_batch, local_site_id)

                    # VAE loss 
                    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
                    coefficient_reconstruction_loss = criterion(reconstructed_coefficients, local_batch)

                    # reconstruct patch
                    # dataset.recon... changed to datasets[0].recon... because of the multipledatasets
                    dti_space_values = dataset.reconstruct_patch(reconstructed_coefficients).to(ml_def.device)

                    sh_to_dti_recon_loss = ((local_centre_voxel - dti_space_values[:,:len(dataset.non_zero_bvals_index)])**2).mean()

                    adversarial_loss_tracker = []

                    time_to_train_the_adversary = (batch_id == n_batches - 1)

                    # Adversarial

                    for _ in range(adversary_epochs_per_batch):
                        for saved_batch_iter in range(number_of_sites):
                            site_prediction = adversarial_model(saved_batch_data[saved_batch_iter].float().to(ml_def.device)) #, local_site_id)
                            # site_prediction = adversarial_model(reconstructed_coefficients) #, local_site_id)
                            # site_prediction = adversarial_model(dti_space_values) #, local_site_id)

                            # local_site_id_idx = torch.Tensor([torch.argmax(row) for row in local_site_id]).long()
                            # adversarial_loss = adversarial_criterion(site_prediction, local_site_id_idx)
                            local_site_id_idx = torch.Tensor([torch.argmax(row) for row in saved_batch_labels[saved_batch_iter]]).long().to(ml_def.device)
                            adversarial_loss = adversarial_criterion(site_prediction, local_site_id_idx)

                            adversarial_loss_tracker.append(adversarial_loss.data)

                            if time_to_train_the_adversary:
                                adversarial_optimizer.zero_grad()
                                adversarial_loss.backward()
                                adversarial_optimizer.step()
                    
                    #

                    if time_to_train_the_adversary:
                        print(
                            "batch: ", batch_counter, "/", len(generators)*max_generator_len,
                            "  scans: ", mini_scan_set_counter, "/", len(mini_scan_set_lists)*number_of_sites,
                            "  epoch: ", epoch+1, "/", n_epochs, 
                            "  memory used: ", resource.getrusage(resource.RUSAGE_SELF).ru_maxrss,
                            sep="")

                    # if adversarial_loss_tracker != []:
                    #     print(adversarial_loss_tracker) 
                     
                    loss = kl_loss + 1e-20*coefficient_reconstruction_loss + 1e-6*sh_to_dti_recon_loss - gamma*adversarial_loss 
                    loss = loss.detach().requires_grad_(True)

                    scan_loss += loss
                    epoch_loss += loss
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    if i%100 == 0 and i>0:
                        print("loss for next 100th batch ", loss)
                    i += 1

                    batch.grad = None
                    loss.grad = None

Hi,

Are you sure that you don’t keep some states from one iteration to the next?

Also when you do loss = loss.detach().requires_grad_(True) you break the link between the loss value and the rest of the net. So doing loss.backward() is not going to do anything.