PyTorch Gradients

Hi,

You most likely want to open a new topic for this and add the distributed tag.
I don’t know if these constructs are differentiable or not tbh.

@albanD Hi,

I got the same error, and I looked at some examples, but I still don't know how to solve my problem.
Here is my code:

    for iter, input in enumerate(train_loader):
        template = input['template']            #read input
        search = input['search']
        label_cls = input['out_label']
        reg_label = input['reg_label']
        reg_weight = input['reg_weight']

        cfg_cnn = [(2, 16, 2, 0, 3),
                   (16, 32, 2, 0, 3),
                   (32, 64, 2, 0, 3),
                   (64, 128, 1, 1, 3),
                   (128, 256, 1, 1, 3)]
        cfg_kernel = [127, 63, 31, 31, 31]
        cfg_kernel_first = [63,31,15,15,15]

        c1_m = c1_s = torch.zeros(1, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0]).to(device)
        c2_m = c2_s = torch.zeros(1, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1]).to(device)
        c3_m = c3_s = torch.zeros(1, cfg_cnn[2][1], cfg_kernel[2], cfg_kernel[2]).to(device)
        trans_snn = [c1_m, c1_s, c2_m, c2_s, c3_m, c3_s]          # use this list

        for i in range(search.shape[-1]):
            cls_loss_ori, cls_loss_align, reg_loss, trans_snn = model(
                template.squeeze(-1),
                search[:, :, :, :, i],
                trans_snn,
                label_cls[:, :, :, i],
                reg_target=reg_label[:, :, :, :, i],
                reg_weight=reg_weight[:, :, :, i])
            # .......
            loss = cls_loss_ori + cls_loss_align + reg_loss
            optimizer.zero_grad()
            loss.backward()

I think the reason this code errors is that, in the loop, I keep updating the value of the variable trans_snn. However, I have no idea how to solve it by renaming trans_snn. Looking forward to your help. Thank you very much!

If I move trans_snn = [c1_m, c1_s, c2_m, c2_s, c3_m, c3_s] into the loop,
the error does not happen. However, I need the updated trans_snn.
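
In case it makes the question clearer, here is a minimal sketch of the kind of change I am considering (assuming the message is the usual "Trying to backward through the graph a second time" error, and that optimizer.step() sits in the part I elided above): detach every tensor in trans_snn after backward(), so the next iteration keeps the updated values but no longer holds on to the previous iteration's graph.

    for i in range(search.shape[-1]):
        cls_loss_ori, cls_loss_align, reg_loss, trans_snn = model(
            template.squeeze(-1),
            search[:, :, :, :, i],
            trans_snn,
            label_cls[:, :, :, i],
            reg_target=reg_label[:, :, :, :, i],
            reg_weight=reg_weight[:, :, :, i])

        loss = cls_loss_ori + cls_loss_align + reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # assumed to be in the elided part of my loop

        # keep the updated values, but drop the graph that produced them
        trans_snn = [t.detach() for t in trans_snn]

If detaching is not the right approach here, what is the recommended way to carry trans_snn across iterations?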