Optimize Training Speed for Recommendation System Using Subnetworks

I am building a recommendation system inspired by YouTube's "Deep Neural Networks for YouTube Recommendations" paper. Since recommendations need to be served in real time, I structured the model with low-latency prediction in mind. The structure is the following:

     |User Features|                |Item Features|
            |                              |
|Fully Connected NN_user|      |Fully Connected NN_item|
              \                          /
           |Concatenated output of both NNs|
                           |
                 |Fully Connected NN|
                           |
                       |output|

This is all one network built using two sub-networks.

The reason I structured it this way is to create rich embeddings for the user and item based on their features, which I can then store. At prediction time I retrieve the stored embeddings, so only the top NN needs to be executed, which makes prediction very fast. In testing, the model gives good results.
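To make the serving path concrete, here is a minimal sketch of what I mean by "only the top NN needs to be executed" at prediction time (all names, dimensions, and the dict-based embedding store are placeholders for illustration):

```python
import torch
import torch.nn as nn

# Placeholder top network; in my model this sits on top of the
# concatenated user and item embeddings (16 + 16 = 32 here).
top_net = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1))

# Computed offline by the two sub-networks and stored, keyed by id.
user_embeddings = {42: torch.randn(16)}
item_embeddings = {7: torch.randn(16)}

def predict(user_id: int, item_id: int) -> float:
    # Look up cached embeddings instead of re-running the sub-networks.
    x = torch.cat([user_embeddings[user_id], item_embeddings[item_id]]).unsqueeze(0)
    with torch.no_grad():
        return top_net(x).item()

score = predict(42, 7)
```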

My question is about decreasing the time it takes to train this model. Is there a way for PyTorch to execute the sub-networks in parallel? DataParallel splits the data across devices and trains the replicas in parallel, but within each replica I think the two sub-networks are still executed one after the other, even though they don't need to be. The forward section of the model has the following structure:

def sub_network(features, **params):
    ...

def forward(self, user_features, item_features):
    user_embedding = sub_network(user_features)
    item_embedding = sub_network(item_features)
    x = torch.cat([user_embedding, item_embedding], dim=1)
    ...

What is a good strategy for parallelizing the execution of the sub-network functions?