PyTorch implementation of higher-order gradient computation

Hi all,

I would like to compute gradients with respect to a matrix T (on which I call T.requires_grad_(True)), but I am not sure how to achieve this in PyTorch.

First, the matrix T is multiplied with the output of a neural network ("net" in the code below), which is trained on dataset s. The loss at this stage is denoted loss_. The weights of the net are then updated with gradients computed by torch.autograd.grad; I call the updated weights W_hat.
Then, the net is trained on another dataset g, without the matrix T, and a loss is computed again. I want the gradient of this second loss with respect to T. The loss comes from the net with the previously updated weights W_hat (which should depend on T), evaluated on dataset g. This is the setup of the paper [2006.05697] Meta Transition Adaptation for Robust Deep Learning with Noisy Labels, in particular equations 9 and 10.
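In my own notation (not necessarily the paper's), with $\alpha$ = `learning_r`, $\beta$ = `learning_rate`, $\mathcal{L}_s$ = `loss_` on dataset s and $\mathcal{L}_g$ = `loss` on dataset g, the two stages and the gradient I am after would be:

$$
\hat W(T) = W - \alpha \, \nabla_W \mathcal{L}_s(W, T)
$$

$$
T \leftarrow T - \beta \, \nabla_T \mathcal{L}_g\big(\hat W(T)\big)
$$

$$
\nabla_T \mathcal{L}_g\big(\hat W(T)\big) = -\alpha \, \nabla_T \big\langle \nabla_W \mathcal{L}_s(W, T),\; v \big\rangle, \qquad v = \nabla_{\hat W} \mathcal{L}_g(\hat W) \ \text{held fixed}
$$

So T enters the second loss only through $\hat W(T)$, and the last line is a mixed second-derivative product. Autograd can compute it, provided the first gradient is taken with `create_graph=True` and $\hat W$ stays connected to that graph.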

Below is my code. It fails because T.grad is None:
Traceback (most recent call last):
T -= learning_rate * T.grad
TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
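One thing that by itself leaves `T.grad` as `None` is `T` not being a leaf tensor: with `T.requires_grad_(True).cuda()`, the tensor returned by `.cuda()` is a new, non-leaf tensor, so `.grad` is never populated on it. (Also, a tensor built from `numpy.random.randint` is int64, and calling `requires_grad_(True)` on an integer tensor raises a RuntimeError.) Making T a leaf does not fix the error alone, because the loss on dataset g must also depend on T. A CPU-only sketch, where `.to(torch.float64)` is an assumed stand-in for `.cuda()` so it runs without a GPU:

```python
import torch

# Pattern to avoid: enable gradients first, then create a new tensor from it.
# .to(torch.float64) stands in for .cuda(); both return a NEW, non-leaf tensor.
a = torch.rand(3, 3).requires_grad_(True).to(torch.float64)
print(a.is_leaf)  # False: .grad will not be populated for it

# Create/move first, then enable gradients: the result is a leaf.
# On a GPU the equivalent would be torch.rand(3, 3, device="cuda", requires_grad=True).
b = torch.rand(3, 3, dtype=torch.float64).requires_grad_(True)
print(b.is_leaf)  # True

(a.sum() + b.sum()).backward()
print(a.grad)  # None (PyTorch also warns about reading .grad of a non-leaf)
print(b.grad)  # a 3x3 tensor of ones
```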

Should I include T in the gradient computation when the net is trained on dataset s, so that the later backward pass does not fail like this? If so, how should I do this (must I derive the gradients with respect to T by hand and then implement them in code)? Thank you very much!
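For a toy case small enough to differentiate by hand, the autograd route (inner gradient with `create_graph=True`, W_hat built as a new tensor, then one more `torch.autograd.grad` with respect to T) gives the same value as the hand-derived chain rule, which suggests manual derivation is not required. The scalar losses and all numbers below are made up for illustration:

```python
import torch

# Hypothetical scalar toy problem: w plays the role of the net weights, t of T.
w = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
t = torch.tensor(0.8, dtype=torch.float64, requires_grad=True)
x, y = 2.0, 1.0      # stand-in sample from dataset s
x2, y2 = 1.5, 0.3    # stand-in sample from dataset g
alpha = 0.1          # stand-in for learning_r

# Autograd route: keep the graph of the inner gradient, then differentiate w.r.t. t.
loss_s = 0.5 * (t * w * x - y) ** 2
(g_w,) = torch.autograd.grad(loss_s, w, create_graph=True)
w_hat = w - alpha * g_w                  # new tensor, still a function of t
loss_g = 0.5 * (w_hat * x2 - y2) ** 2    # t does not appear explicitly
(g_t_autograd,) = torch.autograd.grad(loss_g, t)

# Hand-derived chain rule for the same problem:
# d(w_hat)/dt = -alpha * x * (2*t*w*x - y)
with torch.no_grad():
    dw_hat_dt = -alpha * x * (2 * t * w * x - y)
    g_t_manual = (w_hat * x2 - y2) * x2 * dw_hat_dt

print(torch.allclose(g_t_autograd, g_t_manual))  # True
```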

Current code:

import copy
import itertools
from collections import OrderedDict

import numpy
import torch
import torch.nn.functional as F

# an example matrix T: it must be floating point to require gradients,
# and requires_grad_ is called after .cuda() so that T is a leaf tensor
T = torch.from_numpy(numpy.random.rand(10, 10)).float().cuda()
T.requires_grad_(True)

net.train()  # net is resnet34
for batch_idx, data in enumerate(zip(itertools.cycle(train_g_loader), train_s_loader)):
    data_g, target_g = data[0][0].cuda(), data[0][1].cuda()
    data_s, target_s = data[1][0].cuda(), data[1][1].cuda()

    # copy parameters for later W_hat computation
    original_weights = OrderedDict()
    for name, param in net.named_parameters():
        if not param.requires_grad:
            print(name)
        else:
            original_weights[name] = copy.deepcopy(param)
    original_weights_keys = tuple(original_weights.keys())

    # for equation 11, to have normal backprop when trained on data_s
    model = copy.deepcopy(net)

    # for T, equation 9
    with torch.enable_grad():
        logits = net(data_s)
        pre1 = T[target_s]  # the row of T for each sample's label
        pre2 = torch.mul(F.softmax(logits, dim=1), pre1)
        loss_ = -(torch.log(pre2.sum(1))).sum(0)
        print('loss_', loss_)

        # manually compute gradients
        grads = torch.autograd.grad(loss_, net.parameters(), create_graph=True, allow_unused=True)
        for param, grad in zip(original_weights_keys, grads):
            if grad is None:
                continue
            else:  # this updates net's parameters as W_hat
                net.state_dict()[param] -= learning_r * grad

    with torch.set_grad_enabled(True):
        pre = net(data_g)
    loss = F.cross_entropy(pre, target_g, reduction='sum')
    loss.backward()
    # below throws error, because T.grad is None
    with torch.no_grad():
        T -= learning_rate * T.grad
        T.clamp_(min=0, max=1.0)  # in-place; T.clamp(...) would return a new tensor
        T.grad.zero_()
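A likely cause of the error on the commented line: `loss` has no path back to `T`. `net(data_g)` reads the leaf `nn.Parameter` objects, and writing W_hat into them through `net.state_dict()` (which returns detached tensors by default) does not give those parameters any connection to the graph of `loss_`. A toy contrast with made-up tensors (not resnet34), where an in-place update under `torch.no_grad()` is an assumed stand-in for an update autograd does not record:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2)
lr = 0.1  # hypothetical inner step size


def stage1_grad(w, T):
    # stand-in for loss_ on dataset s: it involves T
    loss_s = (T @ (w * x)).pow(2).sum()
    (g,) = torch.autograd.grad(loss_s, w, create_graph=True)
    return g


T = torch.eye(2, requires_grad=True)

# (a) Unrecorded in-place update: w stays a plain leaf with no link to T.
w = torch.randn(2, requires_grad=True)
g = stage1_grad(w, T)
with torch.no_grad():
    w -= lr * g
loss_g = (w * x).sum()  # stand-in for the loss on dataset g
(grad_T_a,) = torch.autograd.grad(loss_g, T, allow_unused=True)
print(grad_T_a)  # None: no path from loss_g back to T

# (b) W_hat as a new tensor: the graph back to T is kept.
w2 = torch.randn(2, requires_grad=True)
w_hat = w2 - lr * stage1_grad(w2, T)
loss_g2 = (w_hat * x).sum()
(grad_T_b,) = torch.autograd.grad(loss_g2, T)
print(grad_T_b)  # a 2x2 gradient for T
```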

Since I want to backpropagate to T, T must be part of the computation graph in the second stage of training. Deriving the gradients with respect to T by hand and coding them should resolve this, but I still hope there is an easier way.
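One possibly easier route, sketched under assumptions: instead of writing W_hat into the network, keep it as a dict of new tensors and run the second forward pass with `torch.func.functional_call` (PyTorch 2.0 or later). A single `torch.autograd.grad(loss, T)` call then differentiates through the inner update, so nothing is derived by hand. The small MLP, the random batches, the identity initialisation of `T` and the step sizes are placeholders, not resnet34 or the real loaders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
num_classes = 10
# Small stand-in for resnet34 so the sketch runs anywhere (assumption).
net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, num_classes))
# T as a floating-point leaf tensor; identity initialisation is an assumption.
T = torch.eye(num_classes, requires_grad=True)

learning_r = 0.1       # inner (virtual) step size, name taken from the post
learning_rate = 0.01   # step size for T, name taken from the post

# Hypothetical random batches standing in for train_s_loader / train_g_loader.
data_s, target_s = torch.randn(8, 20), torch.randint(0, num_classes, (8,))
data_g, target_g = torch.randn(8, 20), torch.randint(0, num_classes, (8,))

params = dict(net.named_parameters())

# Stage 1 (dataset s): same loss_ as in the post, built through T.
probs = F.softmax(functional_call(net, params, (data_s,)), dim=1)
loss_ = -torch.log((probs * T[target_s]).sum(1)).sum()
grads = torch.autograd.grad(loss_, list(params.values()), create_graph=True)

# W_hat as NEW tensors (no in-place writes into net), so they depend on T.
w_hat = {name: p - learning_r * g for (name, p), g in zip(params.items(), grads)}

# Stage 2 (dataset g): forward pass with W_hat; T is not in this loss directly.
logits_g = functional_call(net, w_hat, (data_g,))
loss = F.cross_entropy(logits_g, target_g, reduction="sum")
(grad_T,) = torch.autograd.grad(loss, T)

# Update T and project it back onto [0, 1].
with torch.no_grad():
    T -= learning_rate * grad_T
    T.clamp_(min=0.0, max=1.0)
```

This only updates T. The ordinary update of the net's own weights (which the code's comment attributes to equation 11) would still be a separate step, for example a normal forward/backward pass and optimizer step. Using `torch.autograd.grad` rather than `loss.backward()` for T also avoids accumulating second-order terms into the net's `.grad` fields.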