How does training behave when only half of the model is wrapped with `nn.DataParallel`?
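For concreteness, here is a minimal sketch of the kind of setup I mean. The class name `HalfParallelModel` and the layer sizes are hypothetical. The first half runs on a single device over the full batch. On each forward pass, the wrapped second half scatters its input across the available GPUs, runs a replica on each chunk, and gathers the outputs back onto `device_ids[0]`. Without GPUs, `DataParallel` simply calls the wrapped module, so the sketch also runs on CPU.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class HalfParallelModel(nn.Module):
    """Hypothetical model: only the second half is wrapped in DataParallel."""

    def __init__(self):
        super().__init__()
        # First half: runs on one device and always sees the full batch.
        self.first = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        # Second half: the batch is split across GPUs (if any) on every forward.
        self.second = nn.DataParallel(nn.Sequential(nn.BatchNorm1d(16), nn.Linear(16, 4)))

    def forward(self, x):
        return self.second(self.first(x))

# DataParallel expects the wrapped module's parameters on device_ids[0] (cuda:0 by default).
model = HalfParallelModel().to(device)
x = torch.randn(32, 8, device=device)
out = model(x)

loss = out.pow(2).mean()
loss.backward()
# Gradients flow back through the gathered outputs into the unwrapped first half.
print(out.shape, model.first[0].weight.grad is not None)
```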

In my understanding, the tricky part would be any batch norm layers in the second (wrapped) part of the model: their batch statistics would be computed separately on each device, over that device's chunk only, rather than over the whole batch.
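To illustrate the concern without any GPUs, the sketch below mimics what a training-mode batch norm layer would see under `DataParallel` on two devices. It splits a batch into two chunks and normalizes each chunk with its own statistics. The two-way split and the toy data are assumptions. As far as I know, the running statistics have a related issue: `DataParallel` only keeps in-place buffer updates made by the replica on `device_ids[0]`. Also, `nn.SyncBatchNorm` is documented to work with `DistributedDataParallel`, not with `DataParallel`.

```python
import torch
import torch.nn.functional as F

# One feature whose values drift across the batch: 0, 1, ..., 7 (shape N=8, C=1).
x = torch.arange(8, dtype=torch.float32).unsqueeze(1)

# Single device: training-mode batch norm uses the mean/variance of all 8 samples.
full = F.batch_norm(x, None, None, training=True)

# Simulated 2-GPU DataParallel: each replica sees only its own 4-sample chunk,
# so each chunk is normalized with its own mean and variance.
per_device = torch.cat(
    [F.batch_norm(chunk, None, None, training=True) for chunk in x.chunk(2)]
)

print(full[0].item())        # about -1.528, i.e. -3.5 / sqrt(5.25)
print(per_device[0].item())  # about -1.342, i.e. -1.5 / sqrt(1.25)
```

The same sample is normalized differently depending on how the batch happens to be split. This matters most when the per-device chunks are small or not representative of the whole batch.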