What I know about the problem
- Adam is stateful and needs additional memory roughly proportional to the number of parameters in the model (it keeps per-parameter state tensors).
- The model parameters have to be loaded onto device 0.
- The OOM occurs at state['exp_avg_sq'] = torch.zeros_like(p.data), which appears to be the last memory allocation in the optimizer's source code (see the sketch right after this list).
- Neither manual device placement nor nn.DataParallel prevents the OOM error.
- I moved the loss computation into the forward method to reduce memory use during training.
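For context, this is roughly the spot in the Adam implementation where the allocation fails, paraphrased from the PyTorch source as I understand it (a sketch, not the exact code). Since torch.zeros_like(p.data) inherits each parameter's device, and all of my parameters sit on cuda:0, the optimizer state is also created on cuda:0:

# Inside Adam.step(): state is created lazily, once per parameter.
for p in group['params']:
    state = self.state[p]
    if len(state) == 0:
        state['step'] = 0
        # zeros_like inherits p's device, so both moment buffers land on cuda:0
        state['exp_avg'] = torch.zeros_like(p.data)
        state['exp_avg_sq'] = torch.zeros_like(p.data)  # <- the allocation that OOMs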
Below are my training and forward methods
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn


def train(dataloader, vocabulary_dict, epoch):
    model = ViralClassification(len(vocabulary_dict), 0.5, 6588)  # , device_ids=[0,1], output_device=1
    model.to('cuda:0')
    model.print_gpu_memory_info()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    cudnn.benchmark = True
    cudnn.enabled = True
    for epoch_idx in range(epoch):
        running_loss = 0.0
        for batch_idx, (label, sequence) in enumerate(dataloader):
            # The forward method computes and returns the loss directly.
            loss = model(sequence.to('cuda:0'), label.to('cuda:1'))  # .to('cuda:0')
            running_loss += loss.item()
            loss.backward()
            del loss
            print('gpu_memory_one in training_loop')
            model.print_gpu_memory_info()
            torch.cuda.empty_cache()
            print('gpu_memory_two in training_loop')
            model.print_gpu_memory_info()
            print('bout to step')
            optimizer.step()
            optimizer.zero_grad()
            print('bottom of training loop')
def forward(self, inputs, labels):
    inputs = self.embedding(inputs)  # .to('cuda:0')
    # print('embedded completed')
    (inputs, hidden_state) = self.bilstm_layer(inputs)
    # print('bilstm completed')
    inputs = inputs.to('cuda:1')  # .to() is not in-place, so the result must be reassigned
    self.bilstm_layer.flatten_parameters()
    torch.cuda.synchronize('cuda:0')
    torch.cuda.synchronize('cuda:1')
    inputs = self.attention_layer(inputs)
    # print('attention completed')
    inputs = inputs.view(-1, self.row * 2 * self.lstm_dim)  # .to('cuda:1')
    # print('view transformation completed')
    inputs = self.mlp_one(inputs, self.relu_one)
    # print('mlp one completed')
    inputs = self.mlp_two(inputs, self.relu_two)
    # print('mlp two completed')
    torch.cuda.synchronize('cuda:0')
    torch.cuda.synchronize('cuda:1')
    logits = self._classify(inputs)
    # print('logits completed')
    torch.cuda.empty_cache()
    torch.cuda.synchronize('cuda:0')
    torch.cuda.synchronize('cuda:1')
    # self.print_gpu_memory_info()
    # The loss is computed here instead of in the training loop to reduce memory use.
    loss = self.criterion(logits, labels)
    return loss
The OOM occurs when I call optimizer.step().
My problem is that, right before optimizer.step(), device 1 still has plenty of free memory, but the OOM happens on device 0 because that is where the optimizer performs its calculations.
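For reference, this is roughly how I check the per-device usage right before the step (a minimal sketch; I assume it reports essentially the same numbers as my print_gpu_memory_info helper):

import torch

def report_memory():
    for d in ('cuda:0', 'cuda:1'):
        allocated = torch.cuda.memory_allocated(d) / 1024 ** 2  # MiB currently held by tensors
        reserved = torch.cuda.memory_reserved(d) / 1024 ** 2    # MiB held by the caching allocator
        total = torch.cuda.get_device_properties(d).total_memory / 1024 ** 2
        print(f'{d}: {allocated:.0f} MiB allocated, {reserved:.0f} MiB reserved, {total:.0f} MiB total')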
Is this a problem that checkpointing may be able to solve?
Is it possible to change the location of the optimizer?
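To make the second question concrete: I am imagining something like pre-creating the state tensors on cuda:1 before the first step, so that step() never has to allocate them on cuda:0. This is purely a hypothetical sketch; I don't know whether the update math in step() tolerates state living on a different device than the parameters, which is really what I am asking:

# Hypothetical: pre-populate Adam's per-parameter state on cuda:1 so that
# step() skips its own zeros_like allocations on cuda:0. I am not sure the
# element-wise update in step() accepts state on a different device.
for p in model.parameters():
    optimizer.state[p] = {
        'step': 0,
        'exp_avg': torch.zeros_like(p.data, device='cuda:1'),
        'exp_avg_sq': torch.zeros_like(p.data, device='cuda:1'),
    }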