I am loading data for a variational autoencoder, and the memory used keeps rising until the process is eventually killed. The memory always increases at next(iter(generator)) but never goes down as different data is loaded.
I am loading MRI data, five scans at a time, and pulling batches of 128 voxels from those five scans. Each voxel is pulled in together with its six neighbouring voxels, giving a seven-voxel patch for use in the network.
To try to find this memory leak I have the rest of the model commented out, so the data is only loaded and not used for anything.
I have attempted a number of memory profiling methods but cannot find the cause of the increase between loads. Why is the memory not freed when a different dataset is loaded?
I've added a gc.collect() call each epoch, which I believe helped with a reduced network, but with everything added back there is still an issue. I have done some memory profiling but cannot track down the actual cause. Previously I had tqdm in the loop to track progress; removing it did seem to help with the leak, but there must be something else. Any help would be appreciated.
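For reference, this is the kind of check I have been running around the per-epoch gc.collect() (a minimal sketch; psutil is just what I happen to use to read the resident set size, any RSS readout would do):

import gc
import tracemalloc

import psutil  # assumed available; only used for the OS-level RSS

tracemalloc.start()
process = psutil.Process()

def report_memory(tag):
    # Python-level allocations traced by tracemalloc, plus the OS-level RSS
    current, peak = tracemalloc.get_traced_memory()
    rss = process.memory_info().rss
    print(f"{tag}: traced={current / 1e6:.1f} MB, peak={peak / 1e6:.1f} MB, rss={rss / 1e6:.1f} MB")

report_memory("before collect")
gc.collect()
report_memory("after collect")

I call report_memory() at the start and end of each epoch.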
def train_model(vae_model, adversarial_model, train_scan_set, test_scan_set, ml_def, verbose=False):
    params = {'batch_size': batch_size,
              'shuffle': True,
              'num_workers': 0}
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(vae_model.parameters(), lr=learning_rate)
    saved_batch_data = []
    saved_batch_labels = []
    for epoch in range(n_epochs):
        epoch_start_date = datetime.utcnow()
        epoch_loss = 0.0
        mini_scan_set_counter = 0
        for mini_scan_set_id in range(len(mini_scan_set_lists)):
            mini_scan_set = mini_scan_set_lists[mini_scan_set_id]
            mini_scan_set_counter += 5
            batch_counter = 0
            if True or mini_scan_set_id == 0:  # debugging leftover: always rebuild the loaders
                datasets = []
                for dataset in datasets:  # note: datasets was just rebound to [], so this never runs
                    del dataset
                generators = []
                for mini_scan_id, scan in enumerate(mini_scan_set):
                    scan_dir = scan.get_storage_address()
                    dataset = Dataset(scan)
                    assert len(dataset) > 0, ["len is 0", scan.scan_name]
                    datasets.append(dataset)
                    training_dataset_generator = torch.utils.data.DataLoader(dataset, **params)
                    generators.append(training_dataset_generator)  # one DataLoader per scan, five in total
            scan_loss = 0.0
            i = 0
            max_generator_len = np.max(list(map(len, generators)))
            for itera in range(max_generator_len):
                saved_batch_data = []
                saved_batch_labels = []
                saved_centre_voxel_intensities = []
                n_batches = 0
                for generator, dataset in zip(generators, datasets):
                    batch_i = itera
                    while batch_i >= len(generator):  # wraps for shorter scans, though batch_i is not used below
                        batch_i -= len(generator)
                    # memory grows at this line and is never released
                    batch, labels, centre_voxel_intensities = next(iter(generator))
                    saved_batch_data.append(batch)
                    saved_batch_labels.append(labels)
                    saved_centre_voxel_intensities.append(centre_voxel_intensities)
                    n_batches += 1
                mini_epoch_loss = 0
                for batch_id in range(n_batches):
                    saved_batch = saved_batch_data[batch_id]
                    saved_labels = saved_batch_labels[batch_id]
                    centre_voxel_intensities = saved_centre_voxel_intensities[batch_id]
                    saved_dataset = datasets[batch_id]
                    batch_counter += 1
                    local_batch = saved_batch.to(ml_def.device)
                    local_site_id = saved_labels.to(ml_def.device)
                    local_centre_voxel = centre_voxel_intensities.to(ml_def.device)
                    # forward pass elided in this post: vae_model produces
                    # reconstructed_coefficients, z_mu and z_var from local_batch
                    kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
                    coefficient_reconstruction_loss = criterion(
                        reconstructed_coefficients, local_batch[:, :coefficient_dimension_output])
                    loss = 1e-5 * kl_loss + 1e-32 * coefficient_reconstruction_loss
                    mini_epoch_loss += loss  # note: accumulated as tensors, not with loss.item()
                    scan_loss += loss
                    epoch_loss += loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    batch.grad = None
                    loss.grad = None
        n = gc.collect()  # the manual per-epoch collection mentioned above
        epoch_end_date = datetime.utcnow()
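One thing I am unsure about is the next(iter(generator)) call above: iter() builds a brand-new DataLoader iterator on every call, and with shuffle=True it always returns the first batch of a fresh shuffle rather than stepping through the loader. A variant I have been experimenting with creates one persistent iterator per loader and only restarts it when it is exhausted (a sketch of the inner loop only, not my full code):

# one persistent iterator per DataLoader, created once per pass over the scans
iterators = [iter(generator) for generator in generators]
for itera in range(max_generator_len):
    for gen_id, generator in enumerate(generators):
        try:
            batch, labels, centre_voxel_intensities = next(iterators[gen_id])
        except StopIteration:
            # this scan's loader ran out; restart it so shorter scans keep cycling
            iterators[gen_id] = iter(generator)
            batch, labels, centre_voxel_intensities = next(iterators[gen_id])
        # batches are then saved and trained on exactly as above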
class Dataset(torch.utils.data.Dataset):
    # __init__ not shown in this post; it sets self.dti_nifti (the loaded NIfTI
    # image), self.nii_indices, self.len, self.bzero_volume_ids,
    # self.bvecs_array, self.non_zero_bvals_index and self.site_id

    def __del__(self):
        self.dti_nifti = None

    def __len__(self):
        return self.len

    def __getitem__(self, i):
        # called for every voxel; get_fdata() returns (and caches) the full 4D array
        data = self.dti_nifti.get_fdata()
        voxel_index = self.nii_indices[i]
        x = voxel_index[0]
        y = voxel_index[1]
        z = voxel_index[2]
        # the centre voxel plus its six face neighbours
        patch_indices = [
            [x, y, z],
            [x+1, y, z],
            [x-1, y, z],
            [x, y+1, z],
            [x, y-1, z],
            [x, y, z+1],
            [x, y, z-1],
        ]
        output = []
        centre_voxel_info_fill = 0
        for x, y, z in patch_indices:
            mean_bzero_val = np.mean(data[x, y, z][self.bzero_volume_ids])
            for bvecs, intensity_vals in [(self.bvecs_array[self.non_zero_bvals_index],
                                           data[x, y, z, self.non_zero_bvals_index])]:
                # mirror every gradient direction so both hemispheres are sampled
                vecs_x_orig = [bvec[0] for bvec in bvecs]
                vecs_y_orig = [bvec[1] for bvec in bvecs]
                vecs_z_orig = [bvec[2] for bvec in bvecs]
                vecs_x_neg = [-bvec[0] for bvec in bvecs]
                vecs_y_neg = [-bvec[1] for bvec in bvecs]
                vecs_z_neg = [-bvec[2] for bvec in bvecs]
                vecs_x = vecs_x_orig + vecs_x_neg
                vecs_y = vecs_y_orig + vecs_y_neg
                vecs_z = vecs_z_orig + vecs_z_neg
                cart_space = list(map(sh_convert.cart2sph, vecs_x, vecs_y, vecs_z))
                angles_lat = [ang[0] for ang in cart_space]
                angles_lon = [ang[1] for ang in cart_space]
                if centre_voxel_info_fill == 0:
                    centre_voxel_int_values = intensity_vals  # raw intensities of the centre voxel
                intensity_vals = np.concatenate([intensity_vals, intensity_vals])
                # least-squares spherical-harmonic fit of the mirrored intensities
                sh_coefficients, sh_error = pysh.shtools.SHExpandLSQ(
                    intensity_vals, angles_lat, angles_lon, spherical_harmonic_degree)
                sh_coefficients = sh_convert.reshape_sh_coeffs(
                    sh_coefficients, degree=spherical_harmonic_degree)
                sh_coefficients = sh_coefficients / coefficient_normalisation_factor
                output.extend([mean_bzero_val])
                output.extend(sh_coefficients)
            centre_voxel_info_fill += 1
        values_array = np.array(output)
        one_hot = self.create_one_hot_matrix()
        return np.float32(values_array), np.float32(one_hot), np.float32(centre_voxel_int_values)

    def create_one_hot_matrix(self):
        # scan_set_site_list is a module-level list of site ids
        one_hot = np.zeros((len(scan_set_site_list)))
        site_index = scan_set_site_list.index(self.site_id)
        one_hot[site_index] = 1
        one_hot_torch = torch.from_numpy(one_hot)
        one_hot_torch = one_hot_torch.float()
        return one_hot_torch
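In case it is relevant: the scans are NIfTI images (loaded with nibabel, going by the get_fdata() call), and get_fdata() caches the full floating-point array on the image object, so every __getitem__ call touches a cached copy of the whole scan. A variant of __getitem__ I have considered drops that cache once the patch is built (a sketch; the patch-building code in between is unchanged):

def __getitem__(self, i):
    # get_fdata() caches a float64 copy of the entire 4D scan on self.dti_nifti
    data = self.dti_nifti.get_fdata()
    # ... build the seven-voxel patch exactly as above ...
    self.dti_nifti.uncache()  # drop the cached array before returning
    return np.float32(values_array), np.float32(one_hot), np.float32(centre_voxel_int_values)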