Can we split a large, PyTorch built-in `nn.Module` across multiple GPUs?

e.g. I have an `nn.Linear` whose `in_features = 500e4` and `out_features = 3000`, so its trainable parameters consume about 55 GB of memory (500e4 * 3000 * 4 / 1024 ** 3). My single GPU has 12 GB and I have 8 GPUs on one machine. How can I parallelize this “atomic” built-in module across multiple GPUs?


I’m afraid we don’t provide any construct to do this automatically.
But you can simply create 8 different Linears that each take a subset of the input features: split the input yourself, call each of these Linears, and then add up all the results (assuming you split along the input dimension here, given that it is the biggest).
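A minimal sketch of that idea (the `SplitLinear` name and constructor signature are my own, not a PyTorch API): each shard is a plain `nn.Linear` over a slice of the input features, placed on its own device, and only the first shard keeps the bias so the summed partial outputs match a single big Linear. For a real 8-GPU run you would pass `["cuda:0", ..., "cuda:7"]` as `devices`; here smaller sizes and CPU "devices" are used just to illustrate the mechanics.

```python
import torch
import torch.nn as nn


class SplitLinear(nn.Module):
    """Emulate one big nn.Linear by splitting in_features across devices.

    y = x @ W.T + b decomposes column-wise: each shard computes a partial
    matmul over its slice of the input features, and the partial outputs
    (shape [batch, out_features]) are summed on the first device.
    """

    def __init__(self, in_features, out_features, devices):
        super().__init__()
        n = len(devices)
        # Distribute in_features as evenly as possible across the shards.
        self.splits = [in_features // n + (1 if i < in_features % n else 0)
                       for i in range(n)]
        self.devices = devices
        self.parts = nn.ModuleList(
            # Only shard 0 carries the bias; otherwise it would be added n times.
            nn.Linear(s, out_features, bias=(i == 0)).to(d)
            for i, (s, d) in enumerate(zip(self.splits, devices))
        )

    def forward(self, x):
        chunks = torch.split(x, self.splits, dim=-1)
        outs = [part(chunk.to(dev))
                for part, chunk, dev in zip(self.parts, chunks, self.devices)]
        # Gather the partial results onto the first device and add them.
        return sum(o.to(self.devices[0]) for o in outs)


# Sanity check against a single reference Linear with the same weights.
torch.manual_seed(0)
split = SplitLinear(10, 4, ["cpu", "cpu"])
ref = nn.Linear(10, 4)
with torch.no_grad():
    # nn.Linear weight is (out_features, in_features), so shards concat on dim=1.
    ref.weight.copy_(torch.cat([p.weight for p in split.parts], dim=1))
    ref.bias.copy_(split.parts[0].bias)
x = torch.randn(3, 10)
print(torch.allclose(split(x), ref(x), atol=1e-5))
```

Since each shard is an ordinary `nn.Linear`, autograd handles the backward pass (including the cross-device copies) for you, and each GPU only ever holds its 1/8 slice of the weight matrix.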