# PyTorch implementation for higher-order gradient computation

Hi all,

I would like to compute gradients with respect to a matrix T (created with T.requires_grad_(True)). However, I am not sure how to achieve this in PyTorch.

First, the matrix T is multiplied with the output of a neural network ("net" in the code below), which is trained on dataset s. The loss at this stage is denoted loss_. The weights of the net are then updated with gradients computed by torch.autograd.grad; I call the updated weights W_hat.
Then the net is trained on another dataset g, without the matrix T, and a second loss is computed. I want the gradients of T with respect to this second loss, which depends on T only through the previously updated weights W_hat. Equivalently, I am trying to implement equations 9 and 10 of this paper: [2006.05697] Meta Transition Adaptation for Robust Deep Learning with Noisy Labels (https://arxiv.org/abs/2006.05697).
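
To make the setup concrete, here is a self-contained toy version of what I understand equations 9 and 10 to compute, with a single linear layer standing in for resnet34 (the inner learning rate 0.1 and all names here are placeholders I made up, not from the paper):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 10

# toy stand-ins: W plays the role of the net's weights
W = torch.randn(num_classes, 5, requires_grad=True)
T = torch.eye(num_classes, requires_grad=True)  # transition matrix

data_s = torch.randn(4, 5); target_s = torch.randint(num_classes, (4,))
data_g = torch.randn(4, 5); target_g = torch.randint(num_classes, (4,))

# inner loss on dataset s, corrected by T (equation 9 style)
probs = F.softmax(data_s @ W.t(), dim=1)
loss_s = -torch.log((probs * T[target_s]).sum(1)).sum()

# one SGD step kept in the graph: create_graph=True is the key flag
g_W, = torch.autograd.grad(loss_s, W, create_graph=True)
W_hat = W - 0.1 * g_W

# outer loss on dataset g, using the updated weights (equation 10 style)
loss_g = F.cross_entropy(data_g @ W_hat.t(), target_g)

# the gradient w.r.t. T flows through W_hat via double backprop
g_T, = torch.autograd.grad(loss_g, T)
print(g_T.shape)  # torch.Size([10, 10])
```

Is this create_graph=True route the right way to get the dependence of the second loss on T?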

Below is my code. The error thrown is because T.grad is None:

```
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
```

Should T be included in the gradient computation while the net is trained on dataset s, so that the later backward pass does not hit this error? If so, how do I do that (do I have to derive the gradients with respect to T by hand and implement them, or can autograd track them)? Thank you very much!
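
For reference, here is the minimal behaviour I would expect once T is a float leaf tensor created with requires_grad=True (if T stays an integer tensor from numpy.random.randint, autograd cannot track it at all):

```python
import torch

T = torch.rand(10, 10, requires_grad=True)  # float leaf tensor
loss = (T * 2).sum()
loss.backward()
print(T.grad)  # a 10x10 tensor of 2s, not None
```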

Current code:

```python
import copy
from collections import OrderedDict

import numpy
import torch
import torch.nn.functional as F

# an example matrix T: it must be a float tensor with requires_grad=True,
# otherwise autograd cannot track it (randint would give an int array)
T = numpy.random.rand(10, 10)
T = torch.from_numpy(T).float().cuda().requires_grad_()

net.train() # net is resnet34
data_g, target_g = data[0][0].cuda(), data[0][1].cuda()
data_s, target_s = data[1][0].cuda(), data[1][1].cuda()

# copy the parameters for the later W_hat computation
original_weights = OrderedDict()
for name, param in net.named_parameters():
    original_weights[name] = copy.deepcopy(param)
original_weights_keys = tuple(original_weights.keys())

# for equation 11: a copy of the net, to have normal backprop when trained on data_s
model = copy.deepcopy(net)

# loss with T, equation 9
logits = net(data_s)
pre1 = T[target_s.long()]  # select the row of T for each label
pre2 = F.softmax(logits, dim=1) * pre1
loss_ = -(torch.log(pre2.sum(1))).sum(0)
print('loss_', loss_)

    continue
else:  # this updates the net's parameters as W_hat
```