PyTorch with DDP throws an error with multiple GPUs


I am trying to train a multi-task BERT model with PyTorch on a single node with 8 GPUs. I can train the model on a single GPU, but if I switch to multiple GPUs I get an error about a mismatch in batch size. In debug mode I checked the forward function (called from the validation step): if I specify a batch size of 4 with 2 GPUs, I get 2 input_ids in the batch; if I specify a batch size of 4 with 1 GPU, I get 4 input_ids in the batch, which is right.
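For reference, here is a minimal sketch (plain Python, no torch) of the index arithmetic behind the two multi-GPU modes; the sample counts and world size are made-up values, not from my setup. `nn.DataParallel` scatters one batch across GPUs (so 4 becomes 2 + 2), while DDP with a `DistributedSampler` shards the dataset so each process still builds batches of the full `batch_size`:

```python
# Sketch (no torch) of how batches are partitioned in the two modes.
# Assumed toy setup: 8 samples, 2 GPUs, requested batch_size of 4.

def dataparallel_split(batch, num_gpus):
    # nn.DataParallel scatters ONE batch across GPUs -> fewer samples per GPU
    chunk = (len(batch) + num_gpus - 1) // num_gpus
    return [batch[i * chunk:(i + 1) * chunk] for i in range(num_gpus)]

def distributed_sampler_indices(num_samples, world_size, rank):
    # DistributedSampler gives each DDP process every world_size-th index,
    # so each process builds its OWN batches at the full batch_size
    return list(range(rank, num_samples, world_size))

print(dataparallel_split(list(range(4)), 2))   # [[0, 1], [2, 3]]
print(distributed_sampler_indices(8, 2, 0))    # [0, 2, 4, 6]
print(distributed_sampler_indices(8, 2, 1))    # [1, 3, 5, 7]
```

(The real `DistributedSampler` also shuffles and pads shards to equal length; this only shows the partitioning idea.)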
My forward function is pasted below:

        batch = x

        # NOTE: the arguments to self.bert were cut off when I pasted this;
        # input_ids / attention_mask are the usual Hugging Face inputs.
        bert_output = self.bert(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        sequence_output = bert_output["last_hidden_state"]
        pooled_output = bert_output["pooler_output"]

        # Indices of the samples in this batch that carry each task's labels
        # (taken from the first task mask in each dict)
        masked_seq_indices = np.where(next(iter(batch["token_tasks"].values())))[0]
        chunk_permuted_seq_indices = np.where(next(iter(batch["sequence_tasks"].values())))[0]

        mlm_scores = self.lm(sequence_output[masked_seq_indices])
        mlm_loss = self.mlm_loss(mlm_scores.view(-1, self.vocab_size), batch["mlm_labels"].view(-1))

        chunk_permute_scores = self.sent["chunk_perm"](pooled_output[chunk_permuted_seq_indices])
        chunk_perm_loss = self.chunk_perm_loss(chunk_permute_scores, batch["chunk_permuted_labels"])

        return (mlm_loss, chunk_perm_loss)
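The index-selection step above can be illustrated with a tiny NumPy example (the task name and mask values are made up; the pattern is the same):

```python
import numpy as np

# Toy stand-in for batch["token_tasks"]: a per-sample mask saying which
# of the 4 samples in the batch carry MLM labels.
token_tasks = {"mlm": [1, 0, 1, 1]}

# Same pattern as in the forward function: take the first task's mask
# and convert it to row indices into sequence_output.
masked_seq_indices = np.where(next(iter(token_tasks.values())))[0]
print(masked_seq_indices)  # [0 2 3]
```

The selected rows must line up with `batch["mlm_labels"]`, so the mask and the labels have to come from the same (per-GPU) batch.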

What parameters do I need to add in my forward function, or elsewhere, to get the right batch in the validation step (and in the training step as well)?

Thank you!