Using data_parallel in libtorch with nested modules

I have been following the mnist_parallel example (19540) for using the data_parallel call. My module has nested modules, similar to how torchvision implements GoogLeNet: the Inception layers are a separate module. When I do this, the call to data_parallel crashes. If I instead move the Inception layer (which can be seen here) into the main module, the Sequential call fails with an error similar to:

weight of size [16, 96, 1, 1], expected input[128, 16, 15, 15] to have 96 channels, but got 16 channels instead

The question is: how do I properly implement nested modules in libtorch when I am going to call data_parallel? I apologize for not posting the code for my net; it's a bit big and ugly, but it started life very close to the torchvision GoogLeNet implementation.

Could you write a short code snippet that reproduces this issue?
I'm not sure where the error is coming from, since similar models seem to work.

Apologies for the delay. I got it sorted out. The rules seem to be:

  1. You cannot use Sequential unless you plan to shard your model on purpose. It cannot be used in a model with a generic data_parallel call (where you replace your forward call with data_parallel). If it is, you will get errors indicating that the weights are not on the correct GPU, like the one above.
  2. You may use nested modules; however, they also need to derive from torch::nn::Cloneable, just like the main module.