I have created a seq2seq model with an encoder and a decoder.
The encoder has 2 main “blocks”: an embedding layer (nn.Embedding) and an LSTM (nn.LSTM). The LSTM has a lot of parameters and the GPU can’t handle even this one block, so I can’t even run it on a single GPU.
What should I do? How can I distribute the memory across 8 GPUs in the case of an RNN?
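For reference, the encoder currently looks roughly like this (simplified; the class and argument names are just placeholders to show the structure):

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        return self.lstm(self.embed(x))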
P.S. I also want to say that PyTorch is awesome, thank you for developing it!
I don’t know much about seq2seq, but maybe this is a good reference:
You can manually ship different parts of your model to different GPUs. For example
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.Linear(10, 10)
        self.m2 = nn.Linear(10, 10)
        self.m3 = nn.Linear(10, 10)
        self.m1.cuda(0)  # puts it on GPU 0
        self.m2.cuda(1)  # puts it on GPU 1
        self.m3.cuda(2)  # puts it on GPU 2

    def forward(self, x):
        x = self.m1(x.cuda(0))  # input goes to GPU 0
        x = self.m2(x.cuda(1))  # intermediate result moves to GPU 1
        x = self.m3(x.cuda(2))  # and then to GPU 2
        return x
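Then you can use it like a normal module; autograd moves the gradients back across the devices for you (a quick sketch, assuming the Net class above):

net = Net()
out = net(torch.rand(4, 10))  # the output ends up on GPU 2
loss = out.sum()
loss.backward()               # gradients flow back through all three GPUs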
But in my case, one GPU can’t handle some of the layers.
For example, I would not be able to use self.m1.cuda(0) because it would be “out of memory”.
For example, my LSTM “layer” (self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)) is so big that I can’t fit it on only one GPU.
In this case you might need to adapt the LSTM code here to handle your huge model.
You can also have a look here for hints on where you could eventually split your model between different GPUs.
Actually, you might need to use LSTMCell for that, as what I pointed out earlier will probably not be enough.
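If the size comes mainly from stacking several layers, a rough, untested sketch of the kind of split I mean could look like this (I stack single-layer nn.LSTMs here instead of LSTMCell just to show the per-GPU placement; the class name is made up):

import torch.nn as nn

class SplitLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_size = embed_size if i == 0 else hidden_size
            # each single-layer LSTM lives on its own GPU
            self.layers.append(nn.LSTM(in_size, hidden_size, 1, batch_first=True).cuda(i))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = x.cuda(i)    # move the activations to that layer's GPU
            x, _ = layer(x)  # hidden/cell states are created on the same GPU
        return x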
I don’t have any experience with LSTMs though, so I won’t be able to guide you much further.
Thanks, I will look at this.
Maybe I can use nn.parallel.data_parallel somehow for my big LSTM “layer”?
Not directly, as DataParallel is only for splitting your input data across many GPUs. But you can always split a huge tensor over several GPUs, something like:
import torch

# huge tensor is M x K, split into 2 tensors along the first dimension
hugeTensor = torch.rand(1000, 100)
small_tensors_on_different_gpus = [tensor.cuda(gpu_id) for gpu_id, tensor in
                                   enumerate(hugeTensor.chunk(2, dim=0))]
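And if you later need everything on one device again, you can move the pieces back and concatenate them, e.g.:

# gather the pieces back on one device (here the CPU) when needed
restored = torch.cat([t.cpu() for t in small_tensors_on_different_gpus], dim=0)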
I think you mean class DataParallel, but I mean def data_parallel (link). Or did I misunderstand you?
Both are very similar actually; the difference is just the functional vs. the module interface.
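For instance, both of these do the same thing (a minimal sketch, assuming two visible GPUs):

import torch
import torch.nn as nn

module = nn.Linear(10, 10).cuda(0)
inp = torch.rand(32, 10).cuda(0)

# module interface: wrap once, then call it like a normal module
dp_module = nn.DataParallel(module, device_ids=[0, 1])
out1 = dp_module(inp)

# functional interface: replicate / scatter / gather happens on each call
out2 = nn.parallel.data_parallel(module, inp, device_ids=[0, 1])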