GPU out of memory for simple RLHF

Hello everyone,

I am trying to implement a simple policy gradient for fine-tuning an autoregressive transformer from scratch, but I run into a CUDA out-of-memory error as soon as I use batch size 64 with long sequences (>100 tokens). The log_prob and log_prob_ref tensors take a lot of GPU memory, even though I am on an A100.

Am I doing something wrong?

import numpy as np
import torch

pg_total = []

for j in range(batch_size_rl):
    # Sample the first token from the empirical distribution of start characters
    start_char = np.random.choice(list(aa_to_proba_first.keys()), p=list(aa_to_proba_first.values()))
    initial_context = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)

    # Generate one sequence and collect per-token log-probs under the policy
    # and under the frozen reference model
    state, log_prob, log_prob_ref = model.generate_for_rlfh(idx=initial_context, max_new_tokens=100,
                                                            pad_token_idx=stoi["!"], ref_model=ref_model)

    reward = 1  # fake reward just for demonstrating the issue

    # Policy-gradient term: (reward + reference bonus - entropy-style penalty) * log_prob
    pg = (reward + ref_coef * log_prob_ref - e_coef * log_prob) * log_prob
    pg_total.append(pg)

# Shape (batch_size_rl, 1, max_new_tokens): sum over tokens, then average over the batch
pg_total = torch.stack(pg_total)
out = torch.mean(torch.sum(pg_total, dim=2))
actor_loss = -out

actor_optimizer.zero_grad(set_to_none=True)
actor_loss.backward()
actor_optimizer.step()
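
For reference, the only workaround I have thought of so far is to call backward once per sequence and accumulate gradients, so that only one sequence's graph is alive at a time instead of all 64. A minimal sketch (assuming the same generate_for_rlfh, ref_coef, e_coef, and optimizer as above):

import numpy as np
import torch

actor_optimizer.zero_grad(set_to_none=True)

for j in range(batch_size_rl):
    start_char = np.random.choice(list(aa_to_proba_first.keys()), p=list(aa_to_proba_first.values()))
    initial_context = torch.tensor([[stoi[start_char]]], dtype=torch.long, device=device)

    state, log_prob, log_prob_ref = model.generate_for_rlfh(idx=initial_context, max_new_tokens=100,
                                                            pad_token_idx=stoi["!"], ref_model=ref_model)

    reward = 1  # fake reward, as above

    pg = (reward + ref_coef * log_prob_ref - e_coef * log_prob) * log_prob
    # Per-sample loss, scaled by 1/batch so the accumulated gradient matches the batched mean
    sample_loss = -torch.sum(pg) / batch_size_rl
    sample_loss.backward()  # frees this sequence's graph before the next iteration

actor_optimizer.step()

The idea is that peak memory should then be roughly one sequence's computation graph rather than 64 of them, while the accumulated gradient is the same as in the batched version. If there is a cleaner way to batch the generation itself instead, I'd appreciate pointers.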

Thanks a lot!