LSTM DataParallel hidden state not split between workers

Hello all,
I am trying to run an LSTM with DataParallel, and several threads [1, 2] have mentioned that batch_first=True has to be enforced, since PyTorch splits the data along the first dimension.

  1. The problem is that the hidden state dimensions are still (num_layers * num_directions x batch_size x hidden_size).
  2. This means the hidden state has to be manually permuted to (batch_size x num_layers * num_directions x hidden_size) at initialization, so that it is split among workers along the batch dimension, and then reshaped back to (num_layers * num_directions x batch_size x hidden_size) before being fed to the LSTM in forward(), to keep its dimensions consistent with the LSTM documentation (see the small shape example below).
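
In concrete shapes (using my setup below: batch = 8, nlayers = 3, hiddensize = 40), the permute in item 2 looks like this:

import torch

h0 = torch.randn(8, 3, 40)               # (batch, num_layers, hidden_size): batch-first, so dim 0 can be split
h0 = h0.permute(1, 0, 2).contiguous()    # (num_layers, batch, hidden_size): what nn.LSTM expects
print(h0.shape)                          # torch.Size([3, 8, 40])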

This has been discussed in quite a few threads, and it does not seem to have any clear, clean resolution as of now.

[1] Multi layer RNN with DataParallel
[2] DataParallel LSTM/GRU wrong hidden batch size (8 GPUs)

Going by the strategies in these threads, I have implemented the code below, but PyTorch does not split the hidden state at all, regardless of the dimensions.


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm


class LSTM(nn.Module):

    def __init__(self, input_size, hiddensize, batch, output_size, nlayers):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hiddensize = hiddensize
        self.batch = batch
        self.output_size = output_size
        self.nlayers = nlayers

        # Define LSTM layers
        self.lstm = nn.LSTM(self.input_size, self.hiddensize, self.nlayers, batch_first=True)
        # Define output layer
        self.linear = nn.Linear(self.hiddensize, self.output_size)

    def init_hidden(self):
        # batch_first=True, so the states are created batch-first here
        h0 = torch.randn(self.batch, self.nlayers, self.hiddensize).cuda()
        c0 = torch.randn(self.batch, self.nlayers, self.hiddensize).cuda()
        return (h0, c0)

    def forward(self, input_tensor):
        hiddenvecs = self.init_hidden()
        # permute back to (num_layers, batch, hidden_size) for nn.LSTM
        hiddenvecs = tuple([h.permute(1, 0, 2).contiguous() for h in hiddenvecs])
        lstm_out, hiddenvecs = self.lstm(input_tensor, hiddenvecs)
        # output layer
        pred = self.linear(lstm_out)
        return pred

# LSTM params
nlayers = 3
hiddensize = 40
input_size = 1
output_size = input_size
batch = 8
epochs = 1000000
k = 10  # sequence length

# load batches and ensure the number of samples is divisible by the batch size
# (data is the raw dataset tensor; its construction is not shown here)
x = data[:, :-1, :]
y = data[:, 1:, :]
nbatch_idx = int(x.size(1) / batch) * batch
x = x[:, :nbatch_idx, :]
y = y[:, :nbatch_idx, :]
x = torch.reshape(x, (x.size(1), k, input_size))  # batch_first=True
y = torch.reshape(y, (y.size(1), k, input_size))
# Create dataloader pipeline
train_dataset = torch.utils.data.TensorDataset(x, y)
trainloader = DataLoader(train_dataset, batch_size=batch, shuffle=False, num_workers=0)
print(len(trainloader))


model = LSTM(input_size, hiddensize, batch, output_size, nlayers)
if torch.cuda.device_count() > 1:
    print('Using', torch.cuda.device_count(), 'GPUs')
    model = nn.DataParallel(model)
model.cuda()

# criterion and optimizer are defined as usual (definitions not shown in this post)

print('Start training...')
PATH = 'checkpoints'
loss_history = []
for ep in range(epochs):
    print('Epoch', ep)
    for local_x, local_y in tqdm(trainloader, total=len(trainloader)):
        local_x, local_y = local_x.cuda(), local_y.cuda()
        output = model(local_x.type(torch.cuda.FloatTensor))
        loss = criterion(output, local_y.type(torch.cuda.FloatTensor))
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())

The sizes of the input tensor and target (i.e. x and y) are (batch x sequence_length x input_size). I have initialized the hidden state batch-first and permuted it to batch-second for the LSTM forward. Here is the error message I get:

RuntimeError: Expected hidden[0] size (3, 2, 40), got (3, 8, 40)

So the hidden state of dimension (8, 3, 40) should have been split into (2, 3, 40), since I am using 4 GPUs (2 x 4 = 8), and my permute operation should then have given (3, 2, 40), which is exactly what the network needs. But the error above says the split never happened.

Regardless of how I shape the hidden state (batch first or second), the batch dimension does not get split among the workers. I would really appreciate any help with what's going on. Thank you!

I “might” have found a workaround for this issue, or maybe it is actually the correct way to implement it. According to the torch.nn.LSTM docs:
“If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.”

So the workaround is basically to let nn.LSTM initialize the hidden state itself rather than having separate init_hidden logic. Some might say this is the correct way to initialize the hidden state anyway. Thoughts? I see better results initializing the RNN hidden state with zeros in any case.
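
In terms of my model above, this just means dropping the init_hidden call, so the forward becomes only (a minimal sketch of the change):

    def forward(self, input_tensor):
        # no explicit hidden state: nn.LSTM initializes h_0 and c_0 to zeros itself
        lstm_out, hiddenvecs = self.lstm(input_tensor)
        pred = self.linear(lstm_out)
        return pred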

No need to even init the hidden state, as the LSTM will do it automatically. Your forward should look something like this:

def forward(self, x):
    out, hidden = self.lstm(x)
    out = self.fc(out)
    return out


Thanks for the response, but I do want to explicitly define the hidden states (I plan to develop a custom LSTM-based architecture) and initialize them every time, and that is when DataParallel doesn't split along the batch. Is there a way to force PyTorch to do this?


I need to do exactly the same. Any update from other threads?
My code is getting kind of messy dealing with batch_first in the RNN instance, dim=1 in nn.DataParallel, and all the functions needed to keep the batch dimension of every object in the same position …
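
For example, the dim=1 part is roughly this kind of thing (a rough sketch with placeholder sizes, not my actual code):

import torch
import torch.nn as nn

# keep everything seq-first / layer-first and let DataParallel scatter along
# dim 1 (the batch dimension) instead of the default dim 0
rnn = nn.LSTM(input_size=1, hidden_size=40, num_layers=3)   # batch_first=False
model = nn.DataParallel(rnn, dim=1).cuda()

x = torch.randn(10, 8, 1).cuda()     # (seq_len, batch, input_size)
h0 = torch.zeros(3, 8, 40).cuda()    # (num_layers, batch, hidden_size)
c0 = torch.zeros(3, 8, 40).cuda()
out, (hn, cn) = model(x, (h0, c0))   # the batch dim (dim 1) gets split across GPUs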

Hi Davide,
I ended up implementing the LSTM from scratch in PyTorch, as it gave me greater flexibility to try various modifications without messing with the PyTorch internals. It didn't take much time, though.

But once you have an implementation, you can still automatically run it in parallel/distributed mode using PyTorch's built-in modules. Then, to explicitly define hidden states at every call, have the LSTM call the init method like this:

    def init_states(self, hiddensize, distbatchsize):
        """ Initialize h and c LSTMCell states """
        h = torch.randn(distbatchsize, hiddensize, self.input_dim[0], self.input_dim[1]).cuda()
        c = torch.randn(distbatchsize, hiddensize, self.input_dim[0], self.input_dim[1]).cuda()
        states = (h,c)
        return states

In data parallel mode there is a separate replica of the network on each GPU, so this function is called locally on each GPU with that GPU's local batch size. The hidden states should therefore be initialized with that local batch size.

The “distbatchsize” here does exactly that. It is the batch size local to the GPU after the full batch has been scattered by the distributed module. So if your full batch size = 8 and num_gpus = 4, then distbatchsize = 8 / 4 = 2. PyTorch computes this automatically, splits your input tensors into chunks of that size, and sends one chunk to each GPU (similar to the scatter operation in MPI). So how do you find distbatchsize? The snippet below is from the LSTM forward method: the first dimension of the scattered input is the local batch size, i.e. 2. If you were running this on only 1 GPU, distbatchsize would be 8, i.e. the global batch size.

    def forward(self, input_tensor):
        """ Forward Pass """

        # Find the local batch size
        distbatchsize = input_tensor.size(0)

        # continue with init_states and the rest of the forward pass
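
Putting the two pieces together, a minimal sketch of the idea with a plain nn.LSTM (names and sizes here are just illustrative, not my actual custom implementation):

import torch
import torch.nn as nn

class SketchLSTM(nn.Module):
    def __init__(self, input_size, hiddensize, nlayers):
        super(SketchLSTM, self).__init__()
        self.hiddensize = hiddensize
        self.nlayers = nlayers
        self.lstm = nn.LSTM(input_size, hiddensize, nlayers, batch_first=True)

    def init_states(self, distbatchsize, device):
        # states are built with the *local* batch size, on the replica's own device
        h = torch.zeros(self.nlayers, distbatchsize, self.hiddensize, device=device)
        c = torch.zeros(self.nlayers, distbatchsize, self.hiddensize, device=device)
        return (h, c)

    def forward(self, input_tensor):
        # local (per-GPU) batch size: first dim of the scattered, batch-first input
        distbatchsize = input_tensor.size(0)
        states = self.init_states(distbatchsize, input_tensor.device)
        out, _ = self.lstm(input_tensor, states)
        return out

Wrapped in nn.DataParallel, each replica only ever sees its own (local_batch, seq_len, input_size) chunk, so the states it builds already have the right batch size.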

Not a clean solution, but it serves me well. Let me know if you have any comments!

Thank you Arvind!
It’s a great solution, will try it out soon!
I need to initialise the hidden states by querying a big tensor that I would rather avoid copying to each GPU.
Although it is pretty inefficient, I think I have to send the query vector to where the big tensor is, extract the result, and then call .cuda() again. Not sure if autograd can track these device changes… hope to give feedback soon.
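
Something like this is what I have in mind (just a rough sketch; big_table and query_idx are placeholder names, not my real objects):

import torch

big_table = torch.randn(1000000, 40)                 # large tensor that stays on the CPU
query_idx = torch.randint(0, 1000000, (8,)).cuda()   # per-batch query, lives on the GPU

rows = big_table[query_idx.cpu()]   # move the query to the CPU side and gather there
h0 = rows.cuda()                    # move only the small result back to the GPU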