Shared weights with model parallelism

Our setup involves the initial part of the network (the input interfaces), which runs on separate GPU cards. Each GPU gets its own portion of the data (model parallelism) and processes it separately.

Each input interface is, in turn, itself a complex nn.Module. Every input interface can occupy one or several cards (say, interface_1 runs on GPUs 0 and 1, interface_2 on GPUs 2 and 3, and so on).
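For concreteness, each interface is shaped roughly like this (a heavily simplified, hypothetical sketch; the real module is much larger and the class name, layer sizes, and devices here are made up):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for one input interface, split across two GPUs
# (e.g. interface_1 on GPUs 0 and 1).
class InputInterface(nn.Module):
    def __init__(self, dev_a="cuda:0", dev_b="cuda:1"):
        super().__init__()
        self.dev_a, self.dev_b = dev_a, dev_b
        self.stage1 = nn.Linear(4096, 1024).to(dev_a)  # first part on one card
        self.stage2 = nn.Linear(1024, 256).to(dev_b)   # second part on another card

    def forward(self, x):
        h = torch.relu(self.stage1(x.to(self.dev_a)))
        return self.stage2(h.to(self.dev_b))
```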

We need to keep the weights of these input interfaces identical throughout training. We also need them to run in parallel to save training time, which is already measured in weeks for our scenario.

The best idea we could come up with was to initialize the interfaces with the same weights and then average their gradients. Since the interfaces are identical, updating the same weights with the same gradients should keep them equal throughout the training process, thus achieving the desired “shared weights” mode.
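In code, what we are trying to do is roughly the following (a simplified sketch; the function name is just illustrative, and we assume the replicas are structurally identical and the optimizers are configured identically):

```python
import torch

def average_shared_gradients(replicas):
    """Average the gradients of corresponding parameters across the replicas.

    Intended to be called after backward() and before optimizer.step();
    with identical initial weights and identical optimizer settings, the
    replicas should then receive identical updates and stay in sync.
    """
    for params in zip(*(r.parameters() for r in replicas)):
        grads = [p.grad for p in params]
        if any(g is None for g in grads):
            continue  # this parameter did not receive a gradient this step
        # Move all gradients to one device, average them, then write the
        # average back to each replica on its own device.
        avg = torch.stack([g.to(grads[0].device) for g in grads]).mean(dim=0)
        for p, g in zip(params, grads):
            p.grad = avg.to(g.device)
```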

However, I cannot find any good way to change the values of these weights and of their gradients, which are represented as Parameter objects in PyTorch. Apparently, PyTorch does not allow doing so.

Our current state is: if we copy.deepcopy the .data of a parameter of the “master” interface and assign it to the .data of the corresponding parameter of the “slave” interface, the values are indeed changed, but .to(device_id) does not work and keeps them on the “master” device. However, we need them to move to the “slave” device.
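Simplified, the current approach looks roughly like this (master and slave stand in for two interface replicas):

```python
import copy
import torch

# master and slave are two structurally identical interface modules.
with torch.no_grad():
    for p_master, p_slave in zip(master.parameters(), slave.parameters()):
        # Overwrite the slave parameter's data with a deep copy of the master's.
        p_slave.data = copy.deepcopy(p_master.data)
```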

Could someone please tell me whether this is possible at all or, if not, whether there is a better way to implement shared weights together with parallel execution for our scenario?

Based on your description, I assume some parallel methods could probably help in your use case.
Using the .data attribute is not recommended, and you could try to adapt some of the synchronization workflows from DDP.
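A rough sketch of what that could look like, assuming one process per input interface and an already initialized process group (e.g. launched via torchrun; broadcast_parameters and allreduce_gradients are just illustrative helpers, not library functions): broadcast the weights once from rank 0 so all replicas start equal, then all-reduce the gradients after every backward() so the updates stay identical, which is essentially what DDP does internally.

```python
import torch
import torch.distributed as dist

def broadcast_parameters(module, src=0):
    # One-time sync: copy rank `src`'s parameter values into every replica,
    # without touching .data.
    with torch.no_grad():
        for p in module.parameters():
            dist.broadcast(p, src=src)

def allreduce_gradients(module):
    # Per-step sync: sum the gradients across ranks and divide by the world
    # size, so each replica applies the same averaged gradient.
    world_size = dist.get_world_size()
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size
```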

You are right, the whole system is a variation of model parallelism. We were also thinking about data parallelism across the interfaces, but that does not work, as in that case PyTorch puts all the input data on GPU 0 before distributing it across the parallel nodes. For our scenario, the input data is so large that it alone takes more than the entire 32 GB of GPU 0's memory.

In that case, I would also try to stick to the DDP approach (and I would generally recommend DDP over DataParallel).
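A minimal single-node sketch of that direction, assuming one process per interface launched with torchrun and, for brevity, an interface that fits on a single GPU (the nn.Linear is just a stand-in for the real interface module; a model-parallel interface spanning several GPUs can also be wrapped by omitting device_ids):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # One process per interface; torchrun sets the rank/world-size env vars.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    interface = torch.nn.Linear(1024, 256).to(rank)  # stand-in for the real interface
    interface = DDP(interface, device_ids=[rank])    # omit device_ids for a
                                                     # multi-GPU (model-parallel) module

    # DDP broadcasts the initial weights from rank 0 and averages the gradients
    # during backward(), so the replicas stay identical without touching .data.
    # Each rank loads its own shard of the input data, so nothing has to be
    # staged on GPU 0 first.
    # ... usual forward / backward / optimizer.step() loop goes here ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```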