How to fix torch error RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

Hi,
I got the error when I use .backward(), but I don’t know how to fix it, could anybody help me, thanks a lot.
here is my train code:

def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adadelta(model.parameters(),
                                     lr=config.learning_rate)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    loss_fct = CrossEntropyLoss()
    total_batch = 1  
    dev_best_loss = float('inf')
    last_improve = 0  
    flag = False  
    writer = SummaryWriter(log_dir=config.log_path + '/' +
                           time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for _, trains in enumerate(train_iter):
            model.zero_grad()
            left_trains, right_trains, (mid_trains, labels) = trains
            outputs = model(left_trains.cuda(config.cuda_id),
                            right_trains.cuda(config.cuda_id),
                            mid_trains.cuda(config.cuda_id))
            loss = loss_fct(outputs, labels.cuda(config.cuda_id))
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
            if total_batch % 10 == 0:
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(
                    msg.format(total_batch, loss.item(), train_acc, dev_loss,
                               dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
        scheduler.step()
    writer.close()
    test(config, model, test_iter)


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config,
                                                                model,
                                                                test_iter,
                                                                test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for evals in data_iter:
            left_evals, right_evals, mid_data = evals
            mid_evals, labels = mid_data
            labels = labels.cuda(config.cuda_id)
            outputs = model(left_evals.cuda(config.cuda_id),
                            right_evals.cuda(config.cuda_id),
                            mid_evals.cuda(config.cuda_id))
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        class_list = [int(i) for i in config.class_list.split()]
        report = metrics.classification_report(labels_all,
                                               predict_all,
                                               target_names=class_list,
                                               digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter[2]), report, confusion
    return acc, loss_total / len(data_iter[2])

here is my model:

class LR_CNN(nn.Module):
    """
    Input shape:
        4D tensor with shape: (batch, K, pad_size, embed_dim)
     Output shape
        3D tensor with shape: (batch, K, len(filter_sizes)*num_filters)
    """
    def __init__(self, config) -> None:
        super(LR_CNN, self).__init__()
        filter_sizes = [int(i) for i in config.filter_sizes.split()]
        self.conv = nn.ModuleList([
            nn.Conv3d(1, config.num_filters, (1, i, config.embed))
            for i in filter_sizes
        ])
        self.relu = nn.ReLU()

    def conv_and_activate(self, x, conv):
        out = conv(x).squeeze(-1)
        out = F.max_pool2d(out, (1, out.size(3)))
        out = self.relu(out)
        out = out.squeeze(-1)
        out = out.permute(0, 2, 1)
        return out

    def forward(self, x):
        x = x.unsqueeze(1)
        out = torch.cat(
            [self.conv_and_activate(x, conv) for conv in self.conv], dim=2)
        return out


class Mid_CNN(nn.Module):
    """
    Input shape:
        3D tensor with shape: (batch, pad_size, embed_dim)
     Output shape
        2D tensor with shape: (batch, len(filter_sizes)*num_filters)
    """
    def __init__(self, config) -> None:
        super(Mid_CNN, self).__init__()
        filter_sizes = [int(i) for i in config.filter_sizes.split()]
        self.conv = nn.ModuleList([
            nn.Conv2d(1, config.num_filters, (i, config.embed))
            for i in filter_sizes
        ])
        self.relu = nn.ReLU()
        self.fc = nn.Linear(
            len(filter_sizes) * config.num_filters,
            len(filter_sizes) * config.num_filters)

    def conv_and_activate(self, x, conv):
        out = conv(x).squeeze(-1)
        out = F.max_pool1d(out, out.size(2))
        out = self.relu(out)
        return out.squeeze(-1)

    def forward(self, x):
        out = torch.cat(
            [self.conv_and_activate(x.unsqueeze(1), conv) for conv in self.conv], dim=1)
        out = self.fc(out)
        out = self.relu(out)
        return out.unsqueeze(1)


class LSTM(nn.Module):
    """
    Input shape:
        3D tensor with shape: (batch, K, len(filter_sizes)*num_filters)
     Output shape
        3D tensor with shape: (batch, K, num_directions * hidden_size)
    """
    def __init__(self, config) -> None:
        super(LSTM, self).__init__()
        input_size = len(config.filter_sizes.split()) * config.num_filters
        self.lstm = nn.LSTM(input_size,
                            config.hidden_size,
                            dropout=config.drop_out,
                            bidirectional=True,
                            batch_first=True,
                            num_layers=config.num_layers)
        # self.fc = nn.Linear(input_size, 2*config.hidden_size)

    def forward(self, x):
        if len(x.size()) == 2:
            x = x.unsqueeze(1)
        out, _ = self.lstm(x)
        # out = self.fc(x)
        return out


class Attention(nn.Module):
    """
    Input shape:
        3D tensor with shape: (batch, K, features)
     Output shape
        2D tensor with shape: (batch, features)
    """
    def __init__(self, config) -> None:
        super(Attention, self).__init__()
        self.w = nn.Parameter(torch.Tensor(config.hidden_size * 2).cuda(config.cuda_id))
        self.b = nn.Parameter(torch.Tensor(config.K).cuda(config.cuda_id))
        self.u = nn.Parameter(torch.Tensor(config.K, config.K).cuda(config.cuda_id))
        self._creat_weight()

    def _creat_weight(self, mean=0.0, std=0.05):
        self.w.data.normal_(mean, std)
        self.u.data.normal_(mean, std)

    def forward(self, x):
        uit = torch.matmul(x, self.w)
        uit += self.b
        uit = torch.matmul(uit, self.u)
        uit = torch.tanh(uit)
        uit = torch.exp(uit)
        ait = torch.sum(uit, dim=1).unsqueeze(1)
        uit = torch.div(uit, ait).unsqueeze(2)
        res = x * uit
        return torch.sum(res, dim=1)


class CBA(nn.Module):

    def __init__(self, config) -> None:
        super(CBA, self).__init__()
        self.lr_cnn = LR_CNN(config)
        self.mid_cnn = Mid_CNN(config)
        self.lr_lstm = LSTM(config)
        self.mid_lstm = LSTM(config)
        self.lr_attention = Attention(config)
        self.fc = nn.Linear(2 * config.hidden_size, 2 * config.hidden_size)
        self.fc1 = nn.Linear(6 * config.hidden_size, 2)
        self.relu = nn.ReLU()

    def forward(self, left, right, mid):
        l_out = self.lr_cnn(left)
        r_out = self.lr_cnn(right)
        mid_out = self.mid_cnn(mid)
        l_out = self.lr_lstm(l_out)
        r_out = self.lr_lstm(r_out)
        mid_out = self.mid_lstm(mid_out)
        mid_out = self.fc(mid_out)
        mid_out = self.relu(mid_out).squeeze(1)
        l_out = self.lr_attention(l_out)
        r_out = self.lr_attention(r_out)
        out = torch.cat((torch.cat((l_out, mid_out), dim=1), r_out), dim=1)
        out = self.fc1(out)
        out = F.softmax(out, dim=-1)
        return out

thanks

You can give retain_graph=True to backward() func to solve this.

thanks, giving retain_graph=True is worded, but it will slow the training a lot, any other solutions?

This error is raised, if the backward pass tries to calculate gradients in a part of the computation graph, which was already processed in a previous iterations and thus the intermediate tensors were freed.
Usually this can happen if you are trying to either call loss.backward() multiple times without recalculating the output and loss or if you are reusing tensors which are attached to the computation graph e.g. in a recurrent layer.

You’ve previously mentioned that calling detach() solved the issue. Where did you add this detach operation?

I calling detach() in class CBA, I think it only lets the program correct, but the result is incorrect. Finally, I find is because of the data processing. I use torchtext to embedding the data and then using Pytorch DataSet and DataLoader to wrap the data. When I process data using torchtext only, the problem is solved, and the result is correct.
But I don’t know why this happened, do you have any idea? Thanks

No, unfortunately I haven’t seen this issue before with torchtext.
Could you create an issue here and post a code snippet, if possible, so that we could track this issue?