Backpropagation for partitioned model training across different mobile devices

I have a DNN model with layers {1, 2, 3, 4, 5}. I split the model into two parts: a fore part, layers {1, 2, 3}, and a hind part, layers {4, 5}. I deploy these two parts on different mobile devices: the fore part (layers {1, 2, 3}) on device A and the hind part (layers {4, 5}) on device B. Training has two phases: a forward pass and a backward pass.
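For reference, here is a minimal sketch of the split I have in mind (the layer types and sizes are made up for illustration; my real model is different):

```python
import torch.nn as nn

# Hypothetical 5-layer model; Linear layers stand in for the real layers.
full_model = nn.Sequential(
    nn.Linear(10, 10),  # layer 1
    nn.Linear(10, 10),  # layer 2
    nn.Linear(10, 10),  # layer 3
    nn.Linear(10, 10),  # layer 4
    nn.Linear(10, 1),   # layer 5
)

fore_part = full_model[:3]  # layers 1-3, deployed on device A
hind_part = full_model[3:]  # layers 4-5, deployed on device B
```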

Forward pass: a mini-batch flows from layer 1 to layer 3 on device A. The intermediate output of layer 3 is sent to device B over a wireless connection. Then, on device B, this intermediate output flows from layer 4 through layer 5 and into the loss function.
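This is roughly how I implement the forward pass (dummy Linear modules stand in for my real layers, and the wireless transport is omitted). I detach the intermediate tensor before sending it, and mark the received copy with `requires_grad_(True)` on device B so that autograd will later give me a gradient for it:

```python
import torch
import torch.nn as nn

# Hypothetical fore/hind parts standing in for layers {1,2,3} and {4,5}.
fore_part = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
hind_part = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 1))

# --- device A ---
x = torch.randn(8, 10)          # one mini-batch
intermediate = fore_part(x)     # forward through layers 1-3

# Send intermediate.detach() over the wireless link (transport not shown).
# Detaching matters: device B must start its own, separate autograd graph.
payload = intermediate.detach()

# --- device B ---
recv = payload.clone().requires_grad_(True)  # leaf tensor; its .grad is what A will need
out = hind_part(recv)           # forward through layers 4-5
loss = out.mean()               # stand-in for the real loss function
```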

Backward pass: device B computes the gradients and updates the weights from layer 5 down to layer 4. The intermediate gradient at layer 4's input (i.e., with respect to layer 3's output) is sent back to device A over the wireless connection. Then device A finishes the remaining backpropagation.
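Here is my current understanding of what the backward pass could look like (again with dummy modules and the transport omitted; I am not sure this is the right approach): on device B, `loss.backward()` fills `recv.grad`, which is the intermediate gradient to send back; on device A, calling `backward()` on the intermediate output with that gradient as an explicit argument continues backpropagation through layers 3 to 1:

```python
import torch
import torch.nn as nn

# Same hypothetical split as above (made-up layer sizes).
fore_part = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
hind_part = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 1))
opt_A = torch.optim.SGD(fore_part.parameters(), lr=0.1)
opt_B = torch.optim.SGD(hind_part.parameters(), lr=0.1)

# Forward pass, as before.
x = torch.randn(8, 10)
intermediate = fore_part(x)                                # device A
recv = intermediate.detach().clone().requires_grad_(True)  # device B's leaf input
loss = hind_part(recv).mean()                              # device B

# --- backward on device B ---
loss.backward()              # fills .grad of hind_part's params AND of recv
opt_B.step()                 # update layers 4-5
grad_to_send = recv.grad     # gradient of the loss w.r.t. layer 3's output

# Send grad_to_send back to device A over the wireless link (not shown).

# --- backward on device A ---
# backward() with an explicit gradient argument resumes backprop from
# `intermediate` through layers 3-1, filling fore_part's parameter grads.
intermediate.backward(grad_to_send)
opt_A.step()                 # update layers 1-3
```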

I can implement the forward pass. However, I do not understand how PyTorch can implement the backward pass. For example, how does device B finish its part of the backpropagation and obtain the intermediate gradient at layer 4's input? And how does device A feed that intermediate gradient into layer 3 and finish the backpropagation?

Thank you for any suggestions or code examples.