Gradient Normalization Loss Can't Be Computed

Hi – I’m trying to implement the GradNorm algorithm from this paper. I’m closely following the code from this repository. However, whenever I run it, I get:

model.task_loss_weights.grad = torch.autograd.grad(grad_norm_loss, model.task_loss_weights)[0]
  File "/home/ubuntu/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/autograd/__init__.py", line 192, in grad
    inputs, allow_unused)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I can see that grad_norm_loss doesn’t require grad (it has no grad_fn), so I set requires_grad=True on it explicitly, at which point I got:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

When I set allow_unused=True, I got None back as my gradient.
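
Here is a stripped-down snippet, with made-up tensors that have nothing to do with my model, that goes through the same three errors (the wrapping in torch.tensor() mirrors how grad_norm_loss is built in my code below):

import torch

w = torch.randn(3, requires_grad=True)   # stand-in for the task weights
x = torch.randn(3)
loss = torch.tensor(torch.sum(w * x))    # mirrors how grad_norm_loss is constructed

# 1) raises "element 0 of tensors does not require grad and does not have a grad_fn"
# torch.autograd.grad(loss, w)

# 2) after requiring grad, raises "One of the differentiated Tensors appears to
#    not have been used in the graph. Set allow_unused=True ..."
loss.requires_grad_(True)
# torch.autograd.grad(loss, w)

# 3) with allow_unused=True the gradient just comes back as None
print(torch.autograd.grad(loss, w, allow_unused=True))   # (None,)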

For context, the paper specifies:

[Screenshot of the GradNorm loss from the paper, roughly: L_grad(t) = Σ_i |G_W^(i)(t) − Ḡ_W(t) · [r_i(t)]^α|, where the target term Ḡ_W(t) · [r_i(t)]^α is treated as a fixed constant]

which is why requires_grad for the constant_term of the grad_norm_loss is set explicitly to False. For reference, here is the relevant section of code:

How can I work around this?


Apologies, I didn’t include the code – here it is:

    # run n_iterations training iterations
    for t in range(n_iterations):

        # get a single batch
        for (it, batch) in enumerate(data_loader):
            # get the X and the targets values
            X = batch[0]
            ts = batch[1]
            if torch.cuda.is_available():
                X = X.cuda()
                ts = ts.cuda()

            # evaluate each task loss L_i(t)
            task_loss = model(X, ts) # this will do a forward pass in the model and will also evaluate the loss
            # compute the weighted loss w_i(t) * L_i(t)
            weighted_task_loss = torch.mul(model.weights, task_loss)
            # initialize the initial loss L(0) if t=0
            if t == 0:
                # set L(0)
                if torch.cuda.is_available():
                    initial_task_loss = task_loss.data.cpu()
                else:
                    initial_task_loss = task_loss.data
                initial_task_loss = initial_task_loss.numpy()

            # get the total loss
            loss = torch.sum(weighted_task_loss)
            # clear the gradients
            optimizer.zero_grad()
            # do the backward pass to compute the gradients for the whole set of weights
            # This is equivalent to compute each \nabla_W L_i(t)
            loss.backward(retain_graph=True)

            # set the gradients of w_i(t) to zero because these gradients have to be updated using the GradNorm loss
            #print('Before turning to 0: {}'.format(model.weights.grad))
            model.weights.grad.data = model.weights.grad.data * 0.0
            #print('Turning to 0: {}'.format(model.weights.grad))


            # switch for each weighting algorithm:
            # --> grad norm
            if args.mode == 'grad_norm':
                
                # get layer of shared weights
                W = model.get_last_shared_layer()

                # get the gradient norms for each of the tasks
                # G^{(i)}_w(t) 
                norms = []
                for i in range(len(task_loss)):
                    # get the gradient of this task loss with respect to the shared parameters
                    gygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)
                    # compute the norm
                    norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))
                norms = torch.stack(norms)
                #print('G_w(t): {}'.format(norms))


                # compute the inverse training rate r_i(t) 
                # \curl{L}_i 
                if torch.cuda.is_available():
                    loss_ratio = task_loss.data.cpu().numpy() / initial_task_loss
                else:
                    loss_ratio = task_loss.data.numpy() / initial_task_loss
                # r_i(t)
                inverse_train_rate = loss_ratio / np.mean(loss_ratio)
                #print('r_i(t): {}'.format(inverse_train_rate))


                # compute the mean norm \tilde{G}_w(t) 
                if torch.cuda.is_available():
                    mean_norm = np.mean(norms.data.cpu().numpy())
                else:
                    mean_norm = np.mean(norms.data.numpy())
                #print('tilde G_w(t): {}'.format(mean_norm))


                # compute the GradNorm loss 
                # this term has to remain constant
                constant_term = torch.tensor(mean_norm * (inverse_train_rate ** args.alpha), requires_grad=False)
                if torch.cuda.is_available():
                    constant_term = constant_term.cuda()
                #print('Constant term: {}'.format(constant_term))
                # this is the GradNorm loss itself
                grad_norm_loss = torch.tensor(torch.sum(torch.abs(norms - constant_term)))
                #print('GradNorm loss {}'.format(grad_norm_loss))

                # compute the gradient for the weights
                model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]

            # do a step with the optimizer
            optimizer.step()

@albanD sorry to bug you, but I’m tagging you because I’ve seen your helpful replies on other autograd-related questions while trying to figure this out. Do you have any ideas on this?

Hi,

I did see your question but I don’t think I have much to say.
If you re-wrap Tensors manually, you break the autograd graph, and so you won’t be able to get gradients.

I think you should remove all the torch.tensor() calls from your code, as you keep breaking the graph even in places where you shouldn’t.
Also, you should never use .data anymore. You can use .detach() if you want to explicitly break the autograd graph.
Finally, keep in mind that any op that is not done with PyTorch primitives won’t be differentiable by autograd.
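
For the part that builds the GradNorm loss, something along these lines should keep the graph intact (an untested sketch reusing the names from your code, and assuming initial_task_loss is kept as a detached tensor rather than converted to numpy):

# gradient norms G^{(i)}_W(t) of each weighted task loss w.r.t. the shared layer
norms = []
for i in range(len(task_loss)):
    gygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)
    norms.append(torch.norm(model.weights[i] * gygw[0]))
norms = torch.stack(norms)

# the target \tilde{G}_W(t) * r_i(t)^alpha must stay constant, so compute it under no_grad
with torch.no_grad():
    loss_ratio = task_loss / initial_task_loss
    inverse_train_rate = loss_ratio / loss_ratio.mean()
    constant_term = norms.mean() * (inverse_train_rate ** args.alpha)

# keep the GradNorm loss as the result of differentiable ops, do not re-wrap it
grad_norm_loss = torch.sum(torch.abs(norms - constant_term))

# gradient of the GradNorm loss w.r.t. the task weights w_i(t)
model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]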


Thanks, I will try this and see what happens.

For reference, here is the GradNorm loss with the torch.tensor() re-wrap removed:

grad_norm_loss = torch.sum(torch.abs(norms - constant_term))
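
For completeness, following the advice above, the other .data / .numpy() conversions in my loop can be kept as detached tensors instead, e.g.:

# keep L(0) as a detached tensor instead of converting it to numpy
if t == 0:
    initial_task_loss = task_loss.detach()

# inverse training rate r_i(t), computed outside the graph
loss_ratio = task_loss.detach() / initial_task_loss
inverse_train_rate = loss_ratio / loss_ratio.mean()

# mean gradient norm \tilde{G}_W(t), also detached
mean_norm = norms.detach().mean()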