CUDA out of memory - sudden large allocation of memory

Dear all,

I cannot figure out how to get rid of an out-of-memory error caused by a sudden, unexplained large memory request (see below):
RuntimeError: CUDA out of memory. Tried to allocate 4.21 GiB (GPU 0; 8.00 GiB total capacity; 128.69 MiB already allocated; 1.92 GiB free; 4.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
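To make the figures in that message easier to compare, here is a small sketch that converts them to a common unit. The regular expression and the name `parse_oom_message` are illustrative assumptions written for this exact message; they are not part of PyTorch, and other versions may word the message differently.

```python
import re

# Multipliers to convert the units printed in the message into bytes.
UNIT_BYTES = {'KiB': 1024, 'MiB': 1024**2, 'GiB': 1024**3}

# Hypothetical pattern, written only for the message quoted above.
OOM_PATTERN = re.compile(
    r'Tried to allocate (?P<requested>[\d.]+ [KMG]iB).*?'
    r'(?P<total>[\d.]+ [KMG]iB) total capacity; '
    r'(?P<allocated>[\d.]+ [KMG]iB) already allocated; '
    r'(?P<free>[\d.]+ [KMG]iB) free; '
    r'(?P<reserved>[\d.]+ [KMG]iB) reserved'
)


def parse_oom_message(message):
    """Return the sizes found in the message as a dict of byte counts."""
    match = OOM_PATTERN.search(message)
    if match is None:
        raise ValueError('message does not match the expected OOM format')
    sizes = {}
    for name, text in match.groupdict().items():
        value, unit = text.split()
        sizes[name] = float(value) * UNIT_BYTES[unit]
    return sizes


message = (
    'CUDA out of memory. Tried to allocate 4.21 GiB (GPU 0; 8.00 GiB total capacity; '
    '128.69 MiB already allocated; 1.92 GiB free; 4.34 GiB reserved in total by PyTorch)'
)
sizes = parse_oom_message(message)

# Print every figure in GiB so they can be compared directly.
for name, num_bytes in sizes.items():
    print(f'{name:>9}: {num_bytes / 1024**3:.3f} GiB')

# Memory held by the caching allocator but not used by live tensors.
cached = sizes['reserved'] - sizes['allocated']
print(f'reserved but not allocated: {cached / 1024**3:.3f} GiB')
```

Put side by side, the single request (4.21 GiB) is larger than the free memory reported (1.92 GiB), while the tensors alive at that moment account for only 128.69 MiB.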

To quickly introduce my pipeline: I am doing link prediction on a heterogeneous graph using PyTorch Geometric. I am working on a single graph (batch_size = 1; I just use the same graph with edges labelled differently between batches).
When I tried to track the memory allocation, I noticed some very strange behaviour: some epochs run without issues while the memory explodes during others, even though I manually clear the cache between epochs:

for epoch in range(params.RUNPARAMS.epochs):
    losses = torch.ones(batches.nb_batches)

    for idx, batch in enumerate(batches):
        batch = batch.to(device)
        print(f'Before train: allocated {round(torch.cuda.memory_allocated(0)/1024**3,6)}GB')
        losses[idx] = model.train_graph(batch, optimizer)
        print(f'After train: allocated {round(torch.cuda.memory_allocated(0)/1024**3,6)}GB')
        batch = batch.to('cpu')
    # Manually clean memory
    torch.cuda.empty_cache()

    # Final prediction, after all the batches have been run
    print(f'Before test: allocated {round(torch.cuda.memory_allocated(0)/1024**3,6)}GB')
    data = data.to(device)
    test_pred, test_loss = model.test_predict_graph(data, penalize_fn=False, set='test')
    data = data.to('cpu')
    torch.cuda.empty_cache()
    print(f'After test: allocated {round(torch.cuda.memory_allocated(0)/1024**3,6)}GB')
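
The prints above only sample `torch.cuda.memory_allocated` before and after each step, so a short-lived spike inside `train_graph` would not show up in them. One possible way to also record the peak is sketched below. The context manager `track_peak` and the helper `to_gib` are hypothetical names introduced for illustration; `torch.cuda.reset_peak_memory_stats`, `torch.cuda.max_memory_allocated` and `torch.cuda.max_memory_reserved` are existing PyTorch functions. The import guard is only there so the snippet also runs on a machine without PyTorch or without a GPU.

```python
import contextlib

try:
    import torch
except ImportError:  # lets the helper load on a machine without PyTorch
    torch = None


def to_gib(num_bytes):
    """Convert a byte count to GiB."""
    return num_bytes / 1024**3


@contextlib.contextmanager
def track_peak(tag, device=0):
    """Print the peak allocated and reserved memory reached inside the with-block."""
    cuda_ok = torch is not None and torch.cuda.is_available()
    if cuda_ok:
        # Reset the counters so the peak only covers this block.
        torch.cuda.reset_peak_memory_stats(device)
    try:
        yield
    finally:
        # Runs even if the block raises, e.g. on an out-of-memory error.
        if cuda_ok:
            peak_alloc = to_gib(torch.cuda.max_memory_allocated(device))
            peak_reserved = to_gib(torch.cuda.max_memory_reserved(device))
            print(f'{tag}: peak allocated {peak_alloc:.6f}GB, '
                  f'peak reserved {peak_reserved:.6f}GB')
        else:
            print(f'{tag}: CUDA not available, nothing measured')


# Smoke test with an empty block; in the loop above the block would contain
# the call to model.train_graph(batch, optimizer).
with track_peak('empty block'):
    pass
```

In the training loop this would wrap the training call, for example `with track_peak(f'epoch {epoch} batch {idx}'):` around the `losses[idx] = model.train_graph(batch, optimizer)` line, so that each batch reports its own peak.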

And here are the outputs:

Epoch 0...
Before train: allocated 0.095437GB
After train: allocated 0.107506GB
Before train: allocated 0.107506GB
After train: allocated 0.107506GB
Before train: allocated 0.107506GB
After train: allocated 0.107506GB
Before train: allocated 0.107506GB
After train: allocated 0.107506GB
Before train: allocated 0.107506GB
After train: allocated 0.107506GB
Before test: allocated 0.01663GB
After test: allocated 0.016637GB

Epoch 1...
Before train: allocated 0.111618GB
After train: allocated 0.111618GB
Before train: allocated 0.111618GB
After train: allocated 0.111618GB
Before train: allocated 0.111618GB
After train: allocated 0.111618GB
Before train: allocated 0.111618GB
After train: allocated 0.111618GB
Before train: allocated 0.111616GB
After train: allocated 0.111616GB
Before test: allocated 0.020755GB
After test: allocated 0.020762GB

Epoch 2...
Before train: allocated 0.115716GB
After train: allocated 0.115716GB
Before train: allocated 0.115716GB
Traceback (most recent call last):
  File "<stdin>", line 6, in <module>
  File "<stdin>", line 41, in run_pda_pipeline
  File "D:\Programmation\GP2-PLIP\src\model.py", line 315, in train_graph
    pred = self(data.x_dict, data.edge_index_dict, self.inspect(data, mode=set), data.edge_attr_dict)
  File "C:\ProgramData\Anaconda3\envs\gp2-plip\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Programmation\GP2-PLIP\src\model.py", line 218, in forward
    z_dict = self.encoder(x_dict, edge_index_dict)
  File "C:\ProgramData\Anaconda3\envs\gp2-plip\lib\site-packages\torch\fx\graph_module.py", line 630, in wrapped_call
    raise e.with_traceback(None)
RuntimeError: CUDA out of memory. Tried to allocate 4.21 GiB (GPU 0; 8.00 GiB total capacity; 128.69 MiB already allocated; 1.92 GiB free; 4.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I checked the allocated memory once more after the error above:

>>> torch.cuda.memory_allocated(0)
124257792

For further information, I am using PyTorch v1.11.0, cudatoolkit v11.3.1 and PyG v2.1.0.

I have no clue why such a large memory request is happening, and I can’t reduce the batch size as it is already equal to 1. Would anyone have an idea?

Thanks for the help, and wishing you a pleasant day!