How to circumvent need for retain_graph=True?


I am trying to train a GRU-based model. I encounter the following error when calling loss.backward():
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Setting retain_graph=True allows the model to run without any errors for a few iterations. However, the memory just keeps on going up until such point that I reach the out of memory error.

Here is my network (adopted from How to Train your Deep Multi-Object Tracker by Xu et al.):

class Munkrs(nn.Module):
    def __init__(self, element_dim, hidden_dim, target_size, biDirenction, minibatch, is_cuda, is_train=True):
        super(Munkrs, self).__init__()
        self.hidden_dim = hidden_dim
        self.bidirect = biDirenction
        self.minibatch = minibatch
        self.is_cuda = is_cuda

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm_row = nn.GRU(element_dim, hidden_dim, bidirectional=biDirenction, num_layers=2)
        self.lstm_col = nn.GRU(hidden_dim*2, hidden_dim, bidirectional=biDirenction, num_layers=2)

        # The linear layer that maps from hidden state space to tag space
        if biDirenction:
            # *2 directions * 2 ways concat
            self.hidden2tag_1 = nn.Linear(hidden_dim * 2, hidden_dim)
            self.hidden2tag_2 = nn.Linear(hidden_dim, 64)
            print('target_size', target_size)
            self.hidden2tag_3 = nn.Linear(64, target_size)
            # * 2 ways concat
            self.hidden2tag_1 = nn.Linear(hidden_dim, target_size)

        self.hidden_row = self.init_hidden(1)
        self.hidden_col = self.init_hidden(1)

    def forward(self, Dt):

        # Dt is of shape [batch, h, w]
        # input_row is of shape [h*w, batch, 1], [time steps, mini batch, element dimension]
        # row lstm #

        input_row = Dt.view(Dt.size(0), -1, 1).permute(1, 0, 2).contiguous()
        lstm_R_out, self.hidden_row = self.lstm_row(input_row, self.hidden_row)

        # column lstm #
        # lstm_R_out is of shape [seq_len=h*w, batch, hidden_size * num_directions]

        # [h * w*batch, hidden_size * num_directions]
        lstm_R_out = lstm_R_out.view(-1, lstm_R_out.size(2))

        # [h * w*batch, 1]
        # lstm_R_out = self.hidden2tag_1(lstm_R_out).view(-1, Dt.size(0))

        # [h,  w, batch, hidden_size * num_directions]
        lstm_R_out = lstm_R_out.view(Dt.size(1), Dt.size(2), Dt.size(0), -1)

        # col wise vector
        # [w,  h, batch, hidden_size * num_directions]
        input_col = lstm_R_out.permute(1, 0, 2, 3).contiguous()
        # [w*h, batch, hidden_size * num_directions]
        input_col = input_col.view(-1, input_col.size(2), input_col.size(3)).contiguous()
        lstm_C_out, self.hidden_col = self.lstm_col(input_col, self.hidden_col)

        # undo col wise vector
        # lstm_out is of shape [seq_len=time steps=w*h, batch, hidden_size * num_directions]

        # [h, w, batch, hidden_size * num_directions]
        lstm_C_out = lstm_C_out.view(Dt.size(2), Dt.size(1), Dt.size(0), -1).permute(1, 0, 2, 3).contiguous()

        # [h*w*batch, hidden_size * num_directions]
        lstm_C_out = lstm_C_out.view(-1, lstm_C_out.size(3))

        # [h*w, batch, 1]
        tag_space = self.hidden2tag_1(lstm_C_out)
        tag_space = self.hidden2tag_2(tag_space)
        tag_space = self.hidden2tag_3(tag_space).view(-1, Dt.size(0))
        tag_scores = torch.sigmoid(tag_space)
        # tag_scores is of shape [batch, h, w] as Dt

        return_this = tag_scores.view(Dt.size(1), Dt.size(2), -1).permute(2, 0, 1).contiguous()
        return return_this

And here is my training loop:

for epoch in range(num_epochs):
	print('Epoch {}/{}'.format(epoch, num_epochs-1))
	print('-' * 10)

	for phase in ['train', 'val']:
		if phase == 'train':

		running_loss = 0.0

		for i, data in enumerate(dataloader[phase]):
			print('i', i)

			# get the inputs 
			x = data['input'].to(device)
			y_gt = data['label'].to(device)

			# zero the parameter gradients

			# forward, track history if only in train
			with torch.set_grad_enabled(phase == 'train'):
				y_pred = model(x)
				#print('input shape', x.shape)
				y_pred = y_pred.view(batch_size, x.shape[1]*x.shape[2])
				y_gt = y_gt.view(batch_size, x.shape[1]*x.shape[2])
				loss = focalLoss(y_pred, y_gt)

				if phase == 'train':

				with open(running_loss_log, 'a') as f:
					df = pd.DataFrame([[phase, epoch, i, loss]], columns=['split', 'epoch', 'i', 'loss'])
					df.to_csv(f, header=False)

			# print statistics
			with torch.no_grad():
				running_loss += loss

			if i % 10 == 9:
				print('[%d, %5d] Running loss: %.3f' % (epoch+1, i+1, running_loss/i))

		if phase == 'train':

		epoch_loss = running_loss / dataset_sizes[phase]
		print('{} Epoch Loss: {:.4f}'.format(phase, epoch_loss))

		with open(epoch_loss_log, 'a') as f:
			df = pd.DataFrame([[phase, epoch, epoch_loss]], columns=['split', 'epoch', 'loss'])
			df.to_csv(f, header=False)

		# deep copy the model according to loss
		if phase == 'val' and epoch_loss < best_loss:
			best_loss = epoch_loss
			best_model_wts = copy.deepcopy(model.state_dict())

	if epoch % 2 == 0:
		latest_model = copy.deepcopy(model.state_dict()), to_save + '_bestloss.pth'), to_save + 'latest.pth')

print('Finished Training')