I am trying to run an LSTM under nn.DataParallel, and several threads [1, 2] have mentioned that batch_first=True has to be enforced, since PyTorch splits the data along the first dimension.
- The problem is that, even with batch_first=True, the hidden state dimensions are still (num_layers * num_directions x batch_size x hidden_size).
- This means the hidden state has to be manually initialized as (batch_size x num_layers * num_directions x hidden_size) so that it is split among the workers along the batch dimension, and then permuted back to (num_layers * num_directions x batch_size x hidden_size) before being passed to the LSTM inside forward(), to keep the dimensions consistent with the LSTM documentation (sketched below).
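In shape terms, my understanding of that strategy is roughly the following (a minimal sketch with the sizes from my setup, not my actual training code):

```python
import torch

nlayers, batch, hiddensize = 3, 8, 40  # sizes from my setup

# Initialize the hidden state batch-first, so DataParallel can (in theory) split dim 0 ...
h0 = torch.randn(batch, nlayers, hiddensize)
# ... then permute back to the layout nn.LSTM actually expects:
# (num_layers * num_directions, batch, hidden_size)
h0_for_lstm = h0.permute(1, 0, 2).contiguous()
print(h0_for_lstm.shape)  # torch.Size([3, 8, 40])
```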
This has been discussed in quite a few threads, but there does not seem to be a clear, clean resolution as of now.
Going by the strategies in those threads, I have implemented the code below, BUT PyTorch does not split the hidden state at all, regardless of the dimensions.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE: data, criterion and optimizer are defined elsewhere in my script

class LSTM(nn.Module):
    def __init__(self, input_size, hiddensize, batch, output_size, nlayers):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hiddensize = hiddensize
        self.batch = batch
        self.output_size = output_size
        self.nlayers = nlayers
        # Define LSTM layers
        self.lstm = nn.LSTM(self.input_size, self.hiddensize, self.nlayers, batch_first=True)
        # Define output layer
        self.linear = nn.Linear(self.hiddensize, self.output_size)

    def init_hidden(self):
        # Hidden state initialized with the batch dimension first (batch_first=True)
        h0 = torch.randn(self.batch, self.nlayers, self.hiddensize).cuda()
        c0 = torch.randn(self.batch, self.nlayers, self.hiddensize).cuda()
        return (h0, c0)

    def forward(self, input_tensor):
        hiddenvecs = self.init_hidden()
        # Permute back to (num_layers, batch, hidden_size) as nn.LSTM expects
        hiddenvecs = tuple([h.permute(1, 0, 2).contiguous() for h in hiddenvecs])
        lstm_out, hiddenvecs = self.lstm(input_tensor, hiddenvecs)
        # Output layer
        pred = self.linear(lstm_out)
        return pred


# LSTM params
nlayers = 3
hiddensize = 40
input_size = 1
output_size = input_size
batch = 8
epochs = 1000000
k = 10  # sequence length

# Load batches and ensure the number of samples is divisible by the batch size
x = data[:, :-1, :]
y = data[:, 1:, :]
nbatch_idx = int(x.size(1) / batch) * batch
x = x[:, :nbatch_idx, :]
y = y[:, :nbatch_idx, :]
x = torch.reshape(x, (x.size(1), k, input_size))  # batch_first=True
y = torch.reshape(y, (y.size(1), k, input_size))

# Create dataloader pipeline
train_dataset = torch.utils.data.TensorDataset(x, y)
trainloader = DataLoader(train_dataset, batch_size=batch, shuffle=False, num_workers=0)
print(len(trainloader))

model = LSTM(input_size, hiddensize, batch, output_size, nlayers)
if torch.cuda.device_count() > 1:
    print('Using', torch.cuda.device_count(), 'GPUs')
    model = nn.DataParallel(model)
model.cuda()

print('Start training...')
PATH = 'checkpoints'
loss_history = []
for ep in range(epochs):
    print('Epoch', ep)
    for local_x, local_y in tqdm(trainloader, total=len(trainloader)):
        local_x, local_y = local_x.cuda(), local_y.cuda()
        output = model(local_x.type(torch.cuda.FloatTensor)).cuda()
        loss = criterion(output, local_y.type(torch.cuda.FloatTensor))
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
```
The sizes of the input and target tensors, i.e. x and y, are (batch x sequence_length x input_size). I have initialized the hidden state with the batch dimension first and permuted it back to batch second for the LSTM forward pass. Here is the error message I get:
```
RuntimeError: Expected hidden size (3, 2, 40), got (3, 8, 40)
```
So the hidden state of dimension (8, 3, 40) should have been split into (2, 3, 40), since I am using 4 GPUs (2 x 4 = 8), and my permute operation should then have given (3, 2, 40), which is exactly what the network needs. But the error above says the split never happened.
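To sanity-check the scatter itself, here is a tiny standalone probe (separate from my model; ShapeProbe is just a made-up name) showing how I understand DataParallel splits a batch-first input across replicas:

```python
import torch
import torch.nn as nn

# Minimal probe to check how DataParallel scatters a batch-first input across replicas.
# Sizes mirror my setup: batch=8, sequence length k=10, input_size=1.
class ShapeProbe(nn.Module):
    def forward(self, x):
        # Each replica prints the chunk of the batch it received
        print(x.device, tuple(x.shape))
        return x

probe = nn.DataParallel(ShapeProbe()).cuda()
out = probe(torch.randn(8, 10, 1).cuda())
# On 4 GPUs I would expect each replica to print (2, 10, 1),
# i.e. the batch dimension 8 split into chunks of 2.
```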
Regardless of how I shape the hidden state (batch first or second), the batch dimension does not get split among the workers. I would really appreciate any help understanding what's going on. Thank you!