I have a DNN model with layers {1, 2, 3, 4, 5}. I split the model into two parts: a fore part, layers {1, 2, 3}, and a hind part, layers {4, 5}, and deploy the two parts on different mobile devices. For example, the fore part, layers {1, 2, 3}, runs on device A and the hind part, layers {4, 5}, runs on device B. Training has two phases: the forward pass and the backward pass.
Forward pass: a mini-batch flows from layer 1 to layer 3 on device A. The intermediate output of layer 3 is sent to device B over a wireless connection, then flows from layer 4 through layer 5 and into the loss function.
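For reference, my forward pass looks roughly like this. The layer shapes and sub-networks are placeholders, and the "wireless" hop is simulated in-process (in the real system the tensor is serialized and sent over a socket):

```python
import torch
import torch.nn as nn

# Placeholder sub-networks: "fore" stands for layers {1, 2, 3} on device A,
# "hind" for layers {4, 5} on device B. Shapes are invented for illustration.
fore = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))  # device A
hind = nn.Sequential(nn.ReLU(), nn.Linear(16, 2))                     # device B
loss_fn = nn.CrossEntropyLoss()

# Device A: run the mini-batch through layers 1-3.
x = torch.randn(4, 8)
act_a = fore(x)

# "Send" the activation to device B (stand-in for the wireless link:
# detach the tensor from device A's autograd graph before transmitting).
received = act_a.detach()

# Device B: resume the forward pass through layers 4-5 and the loss.
received.requires_grad_()   # re-attach autograd on device B
loss = loss_fn(hind(received), torch.randint(0, 2, (4,)))
print(loss.item())
```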
Backward pass: device B computes the gradients and updates the weights from layer 5 down to layer 4. The intermediate gradient at layer 4 (the gradient with respect to layer 4's input, i.e., layer 3's output) is sent back to device A over the wireless connection, and device A then finishes the remaining backpropagation through layers 3 to 1.
I can implement the forward pass. However, I do not understand how to implement the backward pass in PyTorch. Specifically, how does device B finish its part of the backpropagation and obtain the intermediate gradient of layer 4? And how does device A feed that intermediate gradient into layer 3 and finish the backward pass?
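To make the question concrete, here is my guess at one full training step. All names, shapes, and sub-networks are placeholders, and the wireless link is again simulated in-process; I am not sure this handling of the backward pass is correct:

```python
import torch
import torch.nn as nn

# Stand-ins for the two model halves (shapes invented for illustration).
fore = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))  # device A
hind = nn.Sequential(nn.ReLU(), nn.Linear(16, 2))                     # device B
loss_fn = nn.CrossEntropyLoss()

# ---- forward pass ----
x = torch.randn(4, 8)
act_a = fore(x)                            # device A: layers 1-3
act_b = act_a.detach().requires_grad_()    # "wireless" hop to device B
loss = loss_fn(hind(act_b), torch.randint(0, 2, (4,)))

# ---- backward pass ----
# Device B: backprop through layers 5 and 4. Because act_b is a leaf tensor
# with requires_grad=True, autograd stores the intermediate gradient
# (d loss / d layer-4-input) in act_b.grad.
loss.backward()
grad_to_send = act_b.grad                  # send this back to device A

# Device A: feed the received gradient into layer 3's output and finish
# backpropagation through layers 3, 2, 1.
act_a.backward(grad_to_send)

# Both halves now hold parameter gradients and can run their own optimizers.
print(fore[0].weight.grad is not None, hind[1].weight.grad is not None)
```

Is the detach-on-send / `requires_grad_()`-on-receive / `backward(gradient=...)` pattern the right way to stitch the two autograd graphs together, or is there a better mechanism?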
Thank you for any suggestions and code examples.