Hi everyone,
I am trying to fine-tune a T5-based model with reinforcement learning. I currently use the T5 model to paraphrase a relatively short query before feeding it into my environment to retrieve a reward, and then use this reward to compute the loss for training.
I have scoured the forums and have tried to incorporate best practices like deleting variables once they are out of scope and storing my losses with loss.item(). Even more strangely, I am able to train for a couple of batches with ever-increasing memory usage as tracked by nvidia-smi. I have even reduced my batch size to 1, but to no avail: I reach the same number of trained batches as with a batch size of 4 before hitting the failure.
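For what it's worth, a minimal allocator-level check inside the loop (in addition to polling nvidia-smi) would be something like the sketch below; log_gpu_memory is just an illustrative helper name, not part of my code.

import torch

def log_gpu_memory(step, device=0):
    # Allocator-level view (in MiB): memory_allocated() is what live tensors
    # currently occupy, memory_reserved() is what the caching allocator
    # has claimed from the driver.
    allocated = torch.cuda.memory_allocated(device) / 2**20
    reserved = torch.cuda.memory_reserved(device) / 2**20
    print(f"batch {step}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")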
I am performing the training on an RTX 2080 Ti with 11 GB of VRAM.
Does anyone know of potential causes of such behaviour? Thank you for taking the time to read this post.
Edit (code included):
Main Training Loop
for epoch in range(epochs):
    acc_loss = 0.0
    counter = 0
    for i, batch in enumerate(train_dataloader):
        queries = batch["question"]  # pass into step if using relative reward
        contexts = batch["context"]
        answers = batch["answer"]
        del batch

        # Generate paraphrases of the queries with the T5 agent
        output = agent.forward(queries)
        del queries
        reformulations = agent.decode_batch_sequence(
            output.sequences.detach().cpu(), skip_special_tokens=True
        )

        # Get rewards from the environment
        rewards = world_model.step(reformulations, contexts, answers)
        del contexts, answers, reformulations
        rewards = torch.tensor(rewards, dtype=torch.float16).to(agent.device)

        # Policy-gradient loss and optimizer update
        gen_probs, probs = agent.probabilities(output)
        loss = agent.policy_loss(rewards, gen_probs, probs)
        agent.optim.zero_grad()
        loss.backward()
        agent.optim.step()

        loss = loss.detach().item()
        wandb.log({"policy_loss": loss})
        acc_loss += float(loss)
        del loss
        counter += 1

    epoch_loss = acc_loss / counter
    wandb.log({"epoch_loss": epoch_loss})
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        agent.transformer.save_pretrained(args.save_path)
Error:
RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB already allocated; 2.44 MiB free; 9.68 GiB reserved in total by PyTorch)
Further Updates:
After using the debug_memory code snippet found here, it seems like my tensors are not being cleared after the optim.step() call, but are instead only freed once the loss computation for the next batch starts.
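The snippet I am referring to is essentially a gc-based census of live CUDA tensors, roughly along these lines (a sketch, not the exact code from the linked post):

import collections
import gc
import torch

def debug_memory():
    # Count all live CUDA tensors reachable by the garbage collector,
    # grouped by dtype and shape; tensors that survive from one batch
    # to the next show up as counts that keep growing.
    tensors = collections.Counter(
        (str(o.dtype), tuple(o.shape))
        for o in gc.get_objects()
        if torch.is_tensor(o) and o.is_cuda
    )
    for (dtype, shape), count in sorted(tensors.items()):
        print(dtype, shape, count)

Comparing its output right after agent.optim.step() with the output at the start of the next batch is what the observation above is based on.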