Code is getting stuck at dist.barrier() because the GPUs go out of sync in a distributed setup

Hello,

I am fine-tuning an LLM and would need to fit the entire input into the context, which is not really possible because of memory and context-length constraints. So I run inference sequentially: I send the parts of the input one after another and accumulate the loss. The problem is that one GPU finishes its inference faster than the second GPU because the sequences differ in length, and the code gets stuck at dist.barrier(). Any tips on how to make gpu-1 wait for the second one? I have attached a small snippet of my code below with the relevant logic, followed by a distilled sketch of the hang.

        # reset the per-batch accumulators
        policy_chosen_sum.zero_()
        policy_rejected_sum.zero_()
        reference_chosen_sum.zero_()
        reference_rejected_sum.zero_()

        # pick one part at random whose policy logits/labels are kept for later use
        reg_index = random.randint(0, len(input_ids) - 1)

        # sequential forward passes over the parts of the input
        for index in range(len(input_ids)):
            log_policy_probs, policy_logits = self.concatenated_forward(
                self._model, input_ids[index], labels[index]
            )
            if index == reg_index:
                sft_policy_logits = policy_logits
                sft_policy_labels = labels[index]
            del policy_logits

            # reference forward pass with the adapter disabled and no gradients
            with torch.no_grad(), disable_adapter(self._model):
                reference_log_probs, reference_logits = self.concatenated_forward(
                    self._model, input_ids[index], labels[index]
                )
                del reference_logits

            # accumulate log-probs for the chosen vs. rejected parts
            if index < ratio[0]:
                policy_chosen_sum += log_policy_probs
                reference_chosen_sum += reference_log_probs
            else:
                policy_rejected_sum += log_policy_probs
                reference_rejected_sum += reference_log_probs

        loss, chosen_rewards, rejected_rewards = self._loss_fn(
            policy_chosen_sum,
            policy_rejected_sum,
            reference_chosen_sum,
            reference_rejected_sum,
        )

        loss = loss.mean()
        reward_accuracy = (chosen_rewards > rejected_rewards).float().mean().cpu()

        running_val_loss += loss
        running_reward_accuracy += reward_accuracy

        pbar_val.update(1)
        pbar_val.set_description(
            f"{self.epochs_run + 1}|{self.global_step}|Validation Loss: {running_val_loss / (idx + 1)}"
        )
        idx += 1

    # after the validation loop: average the metrics, then synchronise the ranks
    mean_val_loss = running_val_loss / (idx + 1)
    mean_reward_accuracy = running_reward_accuracy / (idx + 1)

    dist.barrier()  # the faster GPU reaches this first and waits for the other rank
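
To make the hang easier to see in isolation, here is a distilled, self-contained sketch of the pattern; the chunk counts, the sleep that stands in for a forward pass, the gloo backend, and the short timeout are all made-up placeholders, not my real settings:

    import datetime
    import time

    import torch.distributed as dist


    def main():
        # "gloo" so the sketch runs on CPU; the timeout is deliberately short
        dist.init_process_group(backend="gloo", timeout=datetime.timedelta(seconds=30))
        rank = dist.get_rank()

        # unequal amount of work per rank, standing in for the different
        # sequence lengths / number of parts on each GPU
        num_chunks = 3 if rank == 0 else 300
        for _ in range(num_chunks):
            time.sleep(0.5)  # stand-in for one forward pass

        # the faster rank arrives here first and blocks until the slower one
        # catches up (or the process-group timeout expires)
        dist.barrier()
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Launched with something like torchrun --nproc_per_node=2 sketch.py.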

I tried increasing the timeout parameter of init_process_group, but one thing I observed is that the second GPU becomes very, very slow as soon as the first GPU reaches dist.barrier(). Before that, the two GPUs stay very close together and both finish their inferences in a relatively short time.
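
For reference, this is roughly how I raised the timeout; the backend name and the two-hour value here are placeholders rather than my exact settings:

    import datetime

    import torch.distributed as dist

    # same init as before, only with a much larger collective timeout
    dist.init_process_group(
        backend="nccl",                       # placeholder backend
        timeout=datetime.timedelta(hours=2),  # raised well above the default
    )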