[Iteratively update and use variables within the loop] RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time

Hi,

I got the error above and have looked at some of the suggested solutions, but I still don't know how to fix my problem.
Here is the code:

    for iter, input in enumerate(train_loader):
        template = input['template']            #read input
        search = input['search']
        label_cls = input['out_label']
        reg_label = input['reg_label']
        reg_weight = input['reg_weight']

        cfg_cnn = [(2, 16, 2, 0, 3),
                   (16, 32, 2, 0, 3),
                   (32, 64, 2, 0, 3),
                   (64, 128, 1, 1, 3),
                   (128, 256, 1, 1, 3)]
        cfg_kernel = [127, 63, 31, 31, 31]
        cfg_kernel_first = [63,31,15,15,15]

        c1_m = c1_s = torch.zeros(1, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0]).to(device)
        c2_m = c2_s = torch.zeros(1, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1]).to(device)
        c3_m = c3_s = torch.zeros(1, cfg_cnn[2][1], cfg_kernel[2], cfg_kernel[2]).to(device)
        trans_snn = [c1_m, c1_s, c2_m, c2_s, c3_m, c3_s]          # use this list

        for i in range(search.shape[-1]):
            cls_loss_ori, cls_loss_align, reg_loss, trans_snn = model(
                template.squeeze(-1), search[:, :, :, :, i], trans_snn, label_cls[:, :, :, i],
                reg_target=reg_label[:, :, :, :, i], reg_weight=reg_weight[:, :, :, i])
             .......
            loss = cls_loss_ori + cls_loss_align + reg_loss
            optimizer.zero_grad()
            loss.backward()

I think the reason this code fails is that I keep updating the variable trans_snn inside the loop. However, I have no idea how to solve it by renaming trans_snn.

If I move trans_snn = [c1_m, c1_s, c2_m, c2_s, c3_m, c3_s] into the inner loop, the error does not happen. However, I need the updated trans_snn from the previous iteration.
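
To make it concrete, here is a toy sketch (nothing to do with my real model; layer and state are just placeholders) of the pattern that I think causes the error: a state tensor that keeps its graph is carried into the next iteration's forward pass, and backward is called every iteration.

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    state = torch.zeros(1, 4)      # carried across iterations, like trans_snn

    for i in range(3):
        out = layer(state)         # this forward depends on the previous graph through state
        loss = out.sum()
        loss.backward()            # on the 2nd iteration this also walks back through the
                                   # 1st iteration's graph, whose buffers were already freed
        state = out                # the graph-attached output becomes the new state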

Could you change your code example to include the forward pass? The one place where I have seen this error come up is when you do the forward pass once, and then do the backward pass more than once before the next forward pass.
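
For example, a toy snippet like this (unrelated to your model) triggers it:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x * x).sum()   # one forward pass
    y.backward()        # first backward succeeds, then the saved intermediates are freed
    y.backward()        # calling backward again on the same graph raises the error above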

Hi, thanks for your reply.
Here is the forward pass for one epoch:

    for iter, input in enumerate(train_loader):
        data_time.update(time.time() - end)
        template = input['template']
        search = input['search']
        label_cls = input['out_label']
        reg_label = input['reg_label']
        reg_weight = input['reg_weight']

        cfg_cnn = [(2, 16, 2, 0, 3),
                   (16, 32, 2, 0, 3),
                   (32, 64, 2, 0, 3),
                   (64, 128, 1, 1, 3),
                   (128, 256, 1, 1, 3)]
        cfg_kernel = [127, 63, 31, 31, 31]
        cfg_kernel_first = [63,31,15,15,15]

        c1_mem = c1_spike = torch.zeros(1, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0]).to(device)
        c2_mem = c2_spike = torch.zeros(1, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1]).to(device)
        c3_mem = c3_spike = torch.zeros(1, cfg_cnn[2][1], cfg_kernel[2], cfg_kernel[2]).to(device)
        trans_snn = [c1_mem, c1_spike, c2_mem, c2_spike, c3_mem, c3_spike]
        for i in range(search.shape[-1]):
            cls_loss_ori, cls_loss_align, reg_loss, trans_snn = model(
                template.squeeze(-1), search[:, :, :, :, i], trans_snn, label_cls[:, :, :, i],
                reg_target=reg_label[:, :, :, :, i], reg_weight=reg_weight[:, :, :, i])

            cls_loss_ori = torch.mean(cls_loss_ori)
            reg_loss = torch.mean(reg_loss)

            if cls_loss_align is not None:
                cls_loss_align = torch.mean(cls_loss_align)
                loss = cls_loss_ori + cls_loss_align + reg_loss
            else:
                cls_loss_align = 0
                loss = cls_loss_ori + reg_loss

            loss = torch.mean(loss)
            # compute gradient and do update step
            optimizer.zero_grad()
            loss.backward()

            # torch.nn.utils.clip_grad_norm(model.parameters(), 10)  # gradient clip

            if is_valid_number(loss.item()):
                optimizer.step()

            # record loss
            loss = loss.item()
            losses.update(loss, template.size(0))

            cls_loss_ori = cls_loss_ori.item()
            cls_losses_ori.update(cls_loss_ori, template.size(0))

            try:
                cls_loss_align = cls_loss_align.item()
            except AttributeError:
                cls_loss_align = 0

            cls_losses_align.update(cls_loss_align, template.size(0))

            reg_loss = reg_loss.item()
            reg_losses.update(reg_loss, template.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

        if (iter + 1) % cfg.PRINT_FREQ == 0:
            logger.info(
                'Epoch: [{0}][{1}/{2}] lr: {lr:.7f}\t Batch Time: {batch_time.avg:.3f}s \t Data Time:{data_time.avg:.3f}s \t CLS_ORI Loss:{cls_loss_ori.avg:.5f} \t CLS_ALIGN Loss:{cls_loss_align.avg:.5f} \t REG Loss:{reg_loss.avg:.5f} \t Loss:{loss.avg:.5f}'.format(
                    epoch, iter + 1, len(train_loader), lr=cur_lr, batch_time=batch_time, data_time=data_time,
                    loss=losses, cls_loss_ori=cls_losses_ori, cls_loss_align=cls_losses_align, reg_loss=reg_losses))

            print_speed((epoch - 1) * len(train_loader) + iter + 1, batch_time.avg,
                        cfg.OCEAN.TRAIN.END_EPOCH * len(train_loader), logger)

        # write to tensorboard
        writer = writer_dict['writer']
        global_steps = writer_dict['train_global_steps']
        writer.add_scalar('train_loss', loss, global_steps)
        writer_dict['train_global_steps'] = global_steps + 1

    return model, writer_dict

Thanks for the code. As far as I can see, this code does not invoke backward more than once per forward pass. So I am not sure why the error comes up. Sorry about that!

Thanks for your reply! :grinning: