# Convert tuple to tensor without breaking graph

I want to be able to take the gradient of the norm-squared gradient of the loss function of a neural network. That's a bit of a mouthful: if theta are the parameters of a neural net (unrolled into a vector), and L is the loss function, then let g be the gradient of L with respect to theta. Letting ||g||^2 be the norm-squared of the gradient, I would like to take the gradient of this with respect to theta. (This is related to computing Hessian-vector products.)
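For reference, the requested quantity can be written out with the chain rule. Writing θ for the parameters and H(θ) = ∇²_θ L(θ) for the Hessian of the loss (notation introduced here, not in the original post), this is a standard identity:

$$
g(\theta) = \nabla_\theta L(\theta), \qquad
\nabla_\theta \|g(\theta)\|^2 = 2\,\big(\nabla_\theta g(\theta)\big)^{\top} g(\theta) = 2\,H(\theta)\,g(\theta).
$$

So the target is twice a Hessian-vector product with v = g, which is why g has to stay connected to the autograd graph instead of being treated as a constant. It also means that if L is linear in θ, then H = 0 and this gradient is identically zero.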

Here’s what I tried:

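```
import torch
linear = torch.nn.Linear(10, 20)
x = torch.randn(1, 10)
L = linear(x).sum()
# grad is a tuple with one tensor per parameter (weight, bias)
grad = torch.autograd.grad(L, linear.parameters(), create_graph=True)
# ... z = ||grad||^2, computed after unrolling grad into one vector (this is where I got stuck)
z.backward()
```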

The problem is that grad is a tuple of tensors, not a single unrolled tensor. Every way I tried to convert the tuple grad into an unrolled vector ends up breaking the graph, so that z.backward() either raises an error or returns None.
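As a point of comparison, here is a minimal sketch showing that the conversion itself does not have to break the graph: `reshape` and `torch.cat` are both differentiable, so the flattened vector stays connected to the parameters. It assumes a squared loss (our choice for illustration) so that the gradient actually depends on the parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(10, 20)
x = torch.randn(1, 10)
# Squared so that the loss is nonlinear in the parameters (illustrative assumption)
L = linear(x).sum() ** 2

params = list(linear.parameters())
# create_graph=True makes the returned gradients themselves differentiable
grad = torch.autograd.grad(L, params, create_graph=True)

# Unroll the tuple into one vector; reshape and cat both keep autograd history
g_flat = torch.cat([g.reshape(-1) for g in grad])
z = g_flat @ g_flat  # ||g||^2

z.backward()
print(g_flat.shape)           # torch.Size([220])  (20*10 weights + 20 biases)
print(z.grad_fn is not None)  # True
```

After `z.backward()`, each parameter's `.grad` holds the gradient of ||g||^2 with respect to that parameter.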


Hi,

The simplest way would be to handle each tensor in the tuple one by one:

```
z = 0
for g in grad:
    z = z + g.pow(2).sum()
```
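The loop above can be run end to end as the following sketch. It again assumes a squared loss for illustration, and the names `dz` and `hvp` are ours. It also checks the identity ∇θ||g||² = 2Hg by computing Hg separately as a vector-Jacobian product of the gradient tuple, with `grad_outputs` set to a detached copy of g.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(10, 20)
x = torch.randn(1, 10)
# Squared so the loss is nonlinear in the parameters (illustrative assumption)
L = linear(x).sum() ** 2

params = list(linear.parameters())
grad = torch.autograd.grad(L, params, create_graph=True)

# Accumulate ||g||^2 one tensor at a time; no flattening needed
z = 0
for g in grad:
    z = z + g.pow(2).sum()

# Gradient of ||g||^2 w.r.t. the parameters; keep the graph for the check below
dz = torch.autograd.grad(z, params, retain_graph=True)

# Cross-check: H g as a vector-Jacobian product of grad with v = g (detached)
hvp = torch.autograd.grad(grad, params, grad_outputs=[v.detach() for v in grad])

for a, b in zip(dz, hvp):
    print(torch.allclose(a, 2 * b, rtol=1e-4, atol=1e-6))  # prints True for each parameter
```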

Replacing the line that computes `z` in my code with

```
z = 0
for g in grad:
    z = z + g.pow(2).sum()
```

actually raises an error:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

I think this is the issue: linear.parameters() returns an iterator over tensors, none of which require gradients. If I had defined the linear network by hand, rather than using the module, I could declare that the weight matrix and bias vector both require gradients, and this should work.

Since I want to use a deep neural network with all sorts of complicated layers, I would like to avoid needing to define it by hand in this way.

Parameters of a network all require gradients by default (that's why their gradients are computed and you can train them). Unless you set `requires_grad` to False, it is True.
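A quick way to confirm this is to inspect `requires_grad` on each parameter of a freshly constructed `nn.Linear` (a small sketch; the variable names are ours). Note that `parameters()` yields the tensors one at a time from a generator rather than returning a tuple.

```python
import torch

linear = torch.nn.Linear(10, 20)
params = linear.parameters()
print(type(params).__name__)  # generator

# Weight and bias both require grad unless explicitly turned off
flags = [p.requires_grad for p in linear.parameters()]
print(flags)  # [True, True]
```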

Ok thanks… I’m still confused about what’s going wrong then. Here’s a minimal example:

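```
import torch
linear = torch.nn.Linear(10, 20)
x = torch.randn(1, 10)
L = linear(x).sum()
grad = torch.autograd.grad(L, linear.parameters(), create_graph=True)
gnorm = 0
for g in grad:
    gnorm = gnorm + g.pow(2).sum()
gnorm.backward()  # RuntimeError: element 0 of tensors does not require grad ...
```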

I could replace the last line with

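Your loss is linear in the parameters: for `L = linear(x).sum()`, the gradient with respect to the weight is a matrix whose every row is `x`, and the gradient with respect to the bias is a vector of ones. Neither depends on the parameters, so autograd records no graph for `g`, `gnorm` is a constant, and there is nothing for `backward()` to differentiate. If your function is not linear, for example `L = linear(x).sum()**2`, then it works as expected.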
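The following sketch contrasts the two cases side by side. `grad_norm_sq` is a hypothetical helper name introduced here; it computes ||dL/dθ||² with the one-by-one loop from earlier in the thread. With the linear loss the result has no autograd history, so calling `backward()` on it would raise the same RuntimeError, while the squared loss backpropagates normally.

```python
import torch

def grad_norm_sq(loss, params):
    """Hypothetical helper: ||dL/dtheta||^2, kept on the graph via create_graph=True."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    total = 0
    for g in grads:
        total = total + g.pow(2).sum()
    return total

torch.manual_seed(0)
linear = torch.nn.Linear(10, 20)
x = torch.randn(1, 10)
params = list(linear.parameters())

# Linear loss: the gradient is constant, so ||g||^2 carries no autograd history
gnorm_lin = grad_norm_sq(linear(x).sum(), params)
print(gnorm_lin.requires_grad)  # False

# Nonlinear loss: ||g||^2 depends on the parameters and can be backpropagated
gnorm_sq = grad_norm_sq(linear(x).sum() ** 2, params)
gnorm_sq.backward()
print(all(p.grad is not None for p in params))  # True
```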