Hello everyone,
I am trying to implement a simple policy gradient for fine-tuning an autoregressive transformer from scratch, but I run into a CUDA out-of-memory error as soon as I use a batch size of 64 with long sequences (>100 tokens). The log_prob and log_prob_ref tensors take up a lot of GPU memory, even on an A100.
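In case it helps, a quick check like this inside the generation loop (a simplified sketch, not my exact code) should show whether the allocation grows with every generated sequence:

# Memory currently held by tensors on the GPU; if this climbs on every
# iteration, the intermediate tensors from generation are being kept
# alive until backward() is called
print(f"allocated GB: {torch.cuda.memory_allocated(device) / 1e9:.2f}")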
Am I doing something wrong? Here is a minimal version of the training loop that reproduces the problem:
pg_total = []
for j in range(batch_size_rl):
    # Sample the first character from its empirical distribution
    start_char = np.random.choice(list(aa_to_proba_first.keys()),
                                  p=list(aa_to_proba_first.values()))
    initial_context = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)
    # Generate one sequence together with the per-token log-probs of the
    # current policy and of the reference model
    state, log_prob, log_prob_ref = model.generate_for_rlfh(
        idx=initial_context, max_new_tokens=100,
        pad_token_idx=stoi["!"], ref_model=ref_model)
    reward = 1  # fake reward, just for demonstrating the issue
    pg = (reward + ref_coef * log_prob_ref - e_coef * log_prob) * log_prob
    pg_total.append(pg)

# Policy-gradient loss: mean over the batch of the per-sequence sums
pg_total = torch.stack(pg_total)
out = torch.mean(torch.sum(pg_total, dim=2))
actor_loss = -out

actor_optimizer.zero_grad(set_to_none=True)
actor_loss.backward()
actor_optimizer.step()
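One workaround I have been considering is to call backward() once per generated sequence, so each sample's graph is freed before the next one is built, and to step the optimizer only at the end. A rough sketch, reusing the same names as above (same objective, just rearranged; not tested):

# Sketch: accumulate gradients one sequence at a time instead of stacking
# every per-token log-prob and building one big graph for the whole batch
actor_optimizer.zero_grad(set_to_none=True)
for j in range(batch_size_rl):
    start_char = np.random.choice(list(aa_to_proba_first.keys()),
                                  p=list(aa_to_proba_first.values()))
    initial_context = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)
    state, log_prob, log_prob_ref = model.generate_for_rlfh(
        idx=initial_context, max_new_tokens=100,
        pad_token_idx=stoi["!"], ref_model=ref_model)
    reward = 1  # fake reward, as above
    pg = (reward + ref_coef * log_prob_ref - e_coef * log_prob) * log_prob
    # Same loss as before (mean over the batch of per-sequence sums), but
    # backward() here releases this sample's graph before the next sample
    sample_loss = -torch.sum(pg) / batch_size_rl
    sample_loss.backward()
actor_optimizer.step()

As far as I understand, backward() releases the saved activations for that sample, so peak memory should scale with a single sequence rather than the whole batch of 64, but I am not sure this is the standard way to handle it.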
Thanks a lot!