Hi guys. I have a big memory leak within this loop:
It’s “within” PyTorch (well I’m 99% sure) because pympler.tracker.print_diff() is showing nothing from loop to loop, but the memory usage is going up (and by a lot - 30MB or so per itera loop)
Can anyone help me find out what the issue is?
for itera in range(max_generator_len):
tracker.print_diff()
saved_batch_data = []
saved_batch_labels = []
saved_centre_voxel_intensities = []
n_batches = 0
for generator, dataset in zip(generators, datasets):
batch_i = itera
while batch_i >= len(generator):
batch_i -= len(generator)
batch, labels, centre_voxel_intensities = next(iter(generator))
saved_batch_data.append(batch)
saved_batch_labels.append(labels)
saved_centre_voxel_intensities.append(centre_voxel_intensities)
n_batches += 1
#
for batch_id in range(n_batches):
batch = saved_batch_data[batch_id]
labels = saved_batch_labels[batch_id]
centre_voxel_intensities = saved_centre_voxel_intensities[batch_id]
dataset = datasets[batch_id]
batch_counter += 1
# transfer to gpu
local_batch, local_site_id, local_centre_voxel = batch.to(ml_def.device), labels.to(ml_def.device), centre_voxel_intensities.to(ml_def.device)
# VAE
reconstructed_coefficients, z_mu, z_var = vae_model(local_batch, local_site_id)
# VAE loss
kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
coefficient_reconstruction_loss = criterion(reconstructed_coefficients, local_batch)
# reconstruct patch
# dataset.recon... changed to datasets[0].recon... because of the multipledatasets
dti_space_values = dataset.reconstruct_patch(reconstructed_coefficients).to(ml_def.device)
sh_to_dti_recon_loss = ((local_centre_voxel - dti_space_values[:,:len(dataset.non_zero_bvals_index)])**2).mean()
adversarial_loss_tracker = []
time_to_train_the_adversary = (batch_id == n_batches - 1)
# Adversarial
for _ in range(adversary_epochs_per_batch):
for saved_batch_iter in range(number_of_sites):
site_prediction = adversarial_model(saved_batch_data[saved_batch_iter].float().to(ml_def.device)) #, local_site_id)
# site_prediction = adversarial_model(reconstructed_coefficients) #, local_site_id)
# site_prediction = adversarial_model(dti_space_values) #, local_site_id)
# local_site_id_idx = torch.Tensor([torch.argmax(row) for row in local_site_id]).long()
# adversarial_loss = adversarial_criterion(site_prediction, local_site_id_idx)
local_site_id_idx = torch.Tensor([torch.argmax(row) for row in saved_batch_labels[saved_batch_iter]]).long().to(ml_def.device)
adversarial_loss = adversarial_criterion(site_prediction, local_site_id_idx)
adversarial_loss_tracker.append(adversarial_loss.data)
if time_to_train_the_adversary:
adversarial_optimizer.zero_grad()
adversarial_loss.backward()
adversarial_optimizer.step()
#
if time_to_train_the_adversary:
print(
"batch: ", batch_counter, "/", len(generators)*max_generator_len,
" scans: ", mini_scan_set_counter, "/", len(mini_scan_set_lists)*number_of_sites,
" epoch: ", epoch+1, "/", n_epochs,
" memory used: ", resource.getrusage(resource.RUSAGE_SELF).ru_maxrss,
sep="")
# if adversarial_loss_tracker != []:
# print(adversarial_loss_tracker)
loss = kl_loss + 1e-20*coefficient_reconstruction_loss + 1e-6*sh_to_dti_recon_loss - gamma*adversarial_loss
loss = loss.detach().requires_grad_(True)
scan_loss += loss
epoch_loss += loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i%100 == 0 and i>0:
print("loss for next 100th batch ", loss)
i += 1
batch.grad = None
loss.grad = None