Optimize Training Speed for Recommendation System Using Subnetworks

I am building a recommendation system inspired by YouTube's "Deep Neural Networks for YouTube Recommendations" paper. Since I need to serve recommendations in real time, I structured it with low-latency prediction in mind. The structure is the following:

      |User Features|                         |Item Features|
               |                                   |
|Fully Connected NN_user|              |Fully Connected NN_item|
                \                        /
              |Concatenated output of both NNs|
                              |
                     |Fully Connected NN|
                              |
                           |output|

This is all one network built using two sub-networks.

The reason I did it this way is to create rich embeddings for the user and the item based on their features, which I can then store. At prediction time, I retrieve the stored embeddings, so only the top NN needs to be executed, which makes it very fast. In testing, the model gives good results.
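Roughly, serving with the stored embeddings could look like the following sketch (the function and argument names are just placeholders for illustration, not my actual code):

import torch

# Rough serving sketch (names are placeholders): the user and item embeddings are
# precomputed with the two sub-networks and stored, so only the top NN runs online.
def score_candidates(top_nn, user_embedding, item_embedding_table):
    # user_embedding: (embed_dim,), item_embedding_table: (num_items, embed_dim)
    user_batch = user_embedding.unsqueeze(0).expand(item_embedding_table.size(0), -1)
    x = torch.cat([user_batch, item_embedding_table], dim=1)
    with torch.no_grad():
        return top_nn(x).squeeze(-1)  # one score per candidate item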

My question is about decreasing the time it takes to train this model. Is there a way for PyTorch to execute the sub-networks in parallel? DataParallel splits the data across devices and trains on the chunks in parallel, but as far as I can tell the two sub-networks are still executed one after the other, even though they don't need to be. The forward section of the model has the following structure:

def sub_network(features, **params):
    ...

def forward(self, user_features, item_features):
    user_embedding = sub_network(user_features)
    item_embedding = sub_network(item_features)
    x = torch.cat([user_embedding, item_embedding], dim=1)
    ...

What is a good strategy for parallelizing the execution of the sub-network functions?

You do have several GPUs to make this worthwhile, right?
Given the asynchronous nature of GPU computation, you can just move one sub-network and its inputs to the second GPU. The work is then queued serially from the CPU but executed in parallel on the two devices. Just be sure not to introduce synchronization points between the two forward passes.
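A minimal sketch of that, assuming two GPUs and using user_net, item_net and top_net as placeholders for your three fully connected parts:

import torch
import torch.nn as nn

class TwoTowerModel(nn.Module):
    def __init__(self, user_net, item_net, top_net):
        super().__init__()
        # Each tower lives on its own GPU; the top network sits on cuda:0.
        self.user_net = user_net.to("cuda:0")
        self.item_net = item_net.to("cuda:1")
        self.top_net = top_net.to("cuda:0")

    def forward(self, user_features, item_features):
        # Both forward passes are only enqueued here; the CPU does not wait,
        # so the two GPUs work on their towers concurrently.
        user_embedding = self.user_net(user_features.to("cuda:0", non_blocking=True))
        item_embedding = self.item_net(item_features.to("cuda:1", non_blocking=True))
        # The copy back to cuda:0 is where the two streams are joined.
        x = torch.cat([user_embedding, item_embedding.to("cuda:0")], dim=1)
        return self.top_net(x)

Autograd handles the cross-device copy, so the backward pass and the rest of your training loop stay the same.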
Alternatively, you could look at the multiprocessing best practices in the PyTorch documentation for advice.

Best regards

Thomas