Hi,
I am trying to train a GRU-based model and I get the following error when calling loss.backward():

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Setting retain_graph=True lets the model run without errors for a few iterations, but memory usage keeps climbing until I eventually hit an out-of-memory error.
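From what I understand, this error usually means that something from a previous iteration's graph is being reused. Here is a tiny standalone example (not my actual model, just my assumption about the cause) that produces the same message by carrying a recurrent hidden state across iterations:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    gru = nn.GRU(input_size=1, hidden_size=8)   # a single-layer GRU stands in for the model
    hidden = torch.zeros(1, 1, 8)               # hidden state carried across iterations
    opt = optim.SGD(gru.parameters(), lr=0.1)

    for step in range(2):
        x = torch.randn(5, 1, 1)                # [seq_len, batch, input_size]
        out, hidden = gru(x, hidden)            # hidden still points into the previous graph
        loss = out.sum()
        opt.zero_grad()
        loss.backward()                         # second iteration raises the same RuntimeError
        opt.step()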
Here is my network (adapted from How to Train your Deep Multi-Object Tracker by Xu et al.):
class Munkrs(nn.Module):
    def __init__(self, element_dim, hidden_dim, target_size, biDirenction, minibatch, is_cuda, is_train=True):
        super(Munkrs, self).__init__()
        self.hidden_dim = hidden_dim
        self.bidirect = biDirenction
        self.minibatch = minibatch
        self.is_cuda = is_cuda

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm_row = nn.GRU(element_dim, hidden_dim, bidirectional=biDirenction, num_layers=2)
        self.lstm_col = nn.GRU(hidden_dim * 2, hidden_dim, bidirectional=biDirenction, num_layers=2)

        # The linear layer that maps from hidden state space to tag space
        if biDirenction:
            # *2 directions * 2 ways concat
            self.hidden2tag_1 = nn.Linear(hidden_dim * 2, hidden_dim)
            self.hidden2tag_2 = nn.Linear(hidden_dim, 64)
            print('target_size', target_size)
            self.hidden2tag_3 = nn.Linear(64, target_size)
        else:
            # * 2 ways concat
            self.hidden2tag_1 = nn.Linear(hidden_dim, target_size)

        self.hidden_row = self.init_hidden(1)
        self.hidden_col = self.init_hidden(1)

    def forward(self, Dt):
        # Dt is of shape [batch, h, w]
        # input_row is of shape [h*w, batch, 1], i.e. [time steps, mini batch, element dimension]

        # row lstm #
        input_row = Dt.view(Dt.size(0), -1, 1).permute(1, 0, 2).contiguous()
        lstm_R_out, self.hidden_row = self.lstm_row(input_row, self.hidden_row)

        # column lstm #
        # lstm_R_out is of shape [seq_len=h*w, batch, hidden_size * num_directions]
        # [h*w*batch, hidden_size * num_directions]
        lstm_R_out = lstm_R_out.view(-1, lstm_R_out.size(2))
        # [h*w*batch, 1]
        # lstm_R_out = self.hidden2tag_1(lstm_R_out).view(-1, Dt.size(0))
        # [h, w, batch, hidden_size * num_directions]
        lstm_R_out = lstm_R_out.view(Dt.size(1), Dt.size(2), Dt.size(0), -1)

        # col wise vector
        # [w, h, batch, hidden_size * num_directions]
        input_col = lstm_R_out.permute(1, 0, 2, 3).contiguous()
        # [w*h, batch, hidden_size * num_directions]
        input_col = input_col.view(-1, input_col.size(2), input_col.size(3)).contiguous()
        lstm_C_out, self.hidden_col = self.lstm_col(input_col, self.hidden_col)

        # undo col wise vector
        # lstm_C_out is of shape [seq_len=w*h, batch, hidden_size * num_directions]
        # [h, w, batch, hidden_size * num_directions]
        lstm_C_out = lstm_C_out.view(Dt.size(2), Dt.size(1), Dt.size(0), -1).permute(1, 0, 2, 3).contiguous()
        # [h*w*batch, hidden_size * num_directions]
        lstm_C_out = lstm_C_out.view(-1, lstm_C_out.size(3))

        # [h*w, batch, 1]
        tag_space = self.hidden2tag_1(lstm_C_out)
        tag_space = self.hidden2tag_2(tag_space)
        tag_space = self.hidden2tag_3(tag_space).view(-1, Dt.size(0))
        tag_scores = torch.sigmoid(tag_space)

        # tag_scores is reshaped to [batch, h, w], like Dt
        return_this = tag_scores.view(Dt.size(1), Dt.size(2), -1).permute(2, 0, 1).contiguous()
        return return_this
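Note that the training loop below calls model.detach_(), which I have not included above. A minimal version that just breaks the autograd links of the stored hidden states might look like this (a sketch, not necessarily the exact code):

    def detach_(self):
        # detach the stored GRU hidden states so the next forward pass does not
        # reference the graph that was built (and freed) in the previous iteration
        self.hidden_row = self.hidden_row.detach()
        self.hidden_col = self.hidden_col.detach()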
And here is my training loop:
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        for i, data in enumerate(dataloader[phase]):
            model.detach_()
            print('i', i)

            # get the inputs
            x = data['input'].to(device)
            y_gt = data['label'].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward, track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                y_pred = model(x)
                # print('input shape', x.shape)
                y_pred = y_pred.view(batch_size, x.shape[1] * x.shape[2])
                y_gt = y_gt.view(batch_size, x.shape[1] * x.shape[2])
                loss = focalLoss(y_pred, y_gt)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            with open(running_loss_log, 'a') as f:
                print(running_loss)
                df = pd.DataFrame([[phase, epoch, i, loss]], columns=['split', 'epoch', 'i', 'loss'])
                df.to_csv(f, header=False)

            # print statistics
            with torch.no_grad():
                running_loss += loss
                if i % 10 == 9:
                    print('[%d, %5d] Running loss: %.3f' % (epoch + 1, i + 1, running_loss / i))

        if phase == 'train':
            exp_lr_scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        print('{} Epoch Loss: {:.4f}'.format(phase, epoch_loss))

        with open(epoch_loss_log, 'a') as f:
            df = pd.DataFrame([[phase, epoch, epoch_loss]], columns=['split', 'epoch', 'loss'])
            df.to_csv(f, header=False)

        # deep copy the model according to loss
        if phase == 'val' and epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

    if epoch % 2 == 0:
        latest_model = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts, to_save + '_bestloss.pth')
        torch.save(latest_model, to_save + 'latest.pth')

print('Finished Training')
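On the logging side, I am accumulating and writing out the raw loss tensor (into running_loss and into the DataFrame). If that is part of the memory problem, the usual pattern I have seen is to convert it to a plain Python float first, for example:

    running_loss += loss.item()   # a float, so no tensor is kept around
    df = pd.DataFrame([[phase, epoch, i, loss.item()]], columns=['split', 'epoch', 'i', 'loss'])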