Backpropagate through gradient step

Sorry for the basic questions; I am fairly new to this topic. I am trying to train a model to guide a gradient step in MAP restoration. My first question is simply whether it is possible to backpropagate through a gradient step (taken with the Adam optimizer).
My second question: how do I do this?
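Why it is possible in principle: a single gradient step on the image is an ordinary differentiable function of the network parameters. The notation here is an assumption for illustration. Take $x$ as the current image (img_ano), $p$ as the prior, $f(x) = \lVert p - x \rVert^2$ as the inner objective (gfunc), $g_\theta(x)$ as the network output, $\eta$ as the step size and $L$ as the outer (Dice) loss. For simplicity this uses a plain gradient-descent step rather than Adam:

$$x' = x - \eta\,\big(\nabla f(x) + g_\theta(x)\big), \qquad \nabla f(x) = 2\,(x - p)$$

$$\frac{\partial L(x')}{\partial \theta} = \frac{\partial L}{\partial x'}\,\frac{\partial x'}{\partial \theta} = -\eta\,\frac{\partial L}{\partial x'}\,\frac{\partial g_\theta(x)}{\partial \theta}$$

When several steps are unrolled, $x$ itself depends on $\theta$, so the Hessian $\nabla^2 f$ also enters the chain rule. An Adam update can be unrolled the same way; it just has more terms.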

My code for one step is as follows (shortened):

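```python
import torch
import torch.nn as nn
import torch.optim as optim

input_img = nn.Parameter(input_img, requires_grad=False)
prior = nn.Parameter(prior, requires_grad=False)
img_ano = nn.Parameter(input_img.clone(), requires_grad=True)

MAP_optimizer = optim.Adam([img_ano], lr=map_step_size)
net_optimizer = optim.Adam(net.parameters(), lr=net_lr_rate)

# Inner (MAP) objective: squared L2 distance between the prior and the current estimate
gfunc = torch.sum((prior.view(-1, prior.numel()) - img_ano.view(-1, img_ano.numel())).pow(2))
gfunc.backward()  # fills img_ano.grad

out = net(...)  # network input elided in the original post
img_ano.grad = img_ano.grad + out  # modify the gradient with the network output

MAP_optimizer.step()  # update img_ano (optimizer steps are not recorded by autograd)
loss = diceloss(img_ano, input_seg)
loss.backward()
net_optimizer.step()  # update network parameters
```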

The problem is that after loss.backward(), the gradients of the network parameters are None (nothing flows back to them), which would not be the case if we were able to backpropagate through the gradient step. How do I solve this?
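A minimal sketch of why the gradients end up as None. The toy sizes, the nn.Linear stand-in for the network and the squared-error stand-in for diceloss are all assumptions. The first function updates the image the way a torch.optim step does, in place and outside autograd. The second writes the same update as an out-of-place expression that autograd records.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 4)    # hypothetical stand-in for the gradient-guiding network
prior = torch.randn(4)
target = torch.randn(4)  # stand-in for the segmentation target
lr = 0.1


def untracked_step():
    """Mimics MAP_optimizer.step(): the image update is not recorded by autograd."""
    net.zero_grad(set_to_none=True)
    x = torch.randn(4, requires_grad=True)
    ((prior - x) ** 2).sum().backward()  # like gfunc.backward(): fills x.grad, keeps no graph
    with torch.no_grad():                # torch.optim updates parameters without tracking
        x -= lr * (x.grad + net(x))
    loss = ((x - target) ** 2).sum()     # squared-error stand-in for diceloss
    loss.backward()
    return net.weight.grad


def tracked_step():
    """The same update as an out-of-place expression that autograd records."""
    net.zero_grad(set_to_none=True)
    x = torch.randn(4, requires_grad=True)
    (g,) = torch.autograd.grad(((prior - x) ** 2).sum(), x, create_graph=True)
    x_new = x - lr * (g + net(x))
    loss = ((x_new - target) ** 2).sum()
    loss.backward()
    return net.weight.grad


print(untracked_step())            # None
print(tracked_step() is not None)  # True
```

The untracked version reproduces the None gradients from the question. The tracked one populates them because the updated image is now a function of net's parameters.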


If you’re new to this, I would recommend using a library built for this, such as higher.


Thank you for the library! But I don’t think it solves my problem, as I want to update the gradient before the optimizer step with:

```python
out = net(...)  # network input elided in the original post
img_ano.grad = img_ano.grad + out
```

This does not seem possible with the library, since you pass the loss itself to the optimizer step (e.g. diffopt.step(loss)), whereas I want to change the gradient I get after gfunc.backward(). Any suggestions?
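Independently of higher, the key difference is how the gradient is obtained. The gradient left in `.grad` by `backward()` is a plain tensor with no history. `torch.autograd.grad(..., create_graph=True)` instead returns a gradient that is itself part of the graph, so it can be modified and still backpropagated through. A small check with f(x) = x³, chosen purely for illustration:

```python
import torch

# Gradient from .backward(): a plain tensor with no history
y = torch.tensor(2.0, requires_grad=True)
(y ** 3).backward()
print(y.grad.item(), y.grad.requires_grad)  # 12.0 False

# Gradient from autograd.grad(create_graph=True): itself part of the graph
x = torch.tensor(2.0, requires_grad=True)
(g,) = torch.autograd.grad(x ** 3, x, create_graph=True)  # g = 3 * x**2
print(g.item(), g.requires_grad)  # 12.0 True

modified = g + 1.0   # "change the gradient", e.g. by adding a network output
modified.backward()  # d(3 * x**2 + 1)/dx = 6 * x
print(x.grad.item())  # 12.0
```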

Do you want to change that gradient in the inner step or the outer step?

The inner step, as sketched below:

```python
inner_loss = l2_loss(img_ano, prior)
inner_loss.backward()

img_ano.grad += net(...)  # change the gradient by adding the network output (input elided)

img_ano = img_ano - step_size * img_ano.grad  # descent step; here I want to take the corresponding Adam step instead (restore_optimizer.step())

loss = diceloss(img_ano, target_img)
loss.backward()  # gather the network parameter gradients, i.e. backpropagate through the gradient step
net_optimizer.step()  # update the network parameters
```
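A sketch of this inner/outer loop in plain PyTorch, without higher, under several assumptions. The 8×8 toy tensors, the Conv2d network and the MSE stand-in for diceloss are hypothetical. The helper `adam_step` is a hand-written, out-of-place Adam update following the standard Adam formulas; it is not torch.optim's implementation. The inner gradient comes from `torch.autograd.grad(..., create_graph=True)` so that the outer `loss.backward()` reaches the network through every unrolled step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: a 1x8x8 "image", a prior and a target
prior = torch.rand(1, 1, 8, 8)
target = torch.rand(1, 1, 8, 8)    # stand-in for target_img
init_img = torch.rand(1, 1, 8, 8)  # stand-in for input_img

# Hypothetical gradient-guiding network: maps the current image to a gradient correction
net = nn.Conv2d(1, 1, kernel_size=3, padding=1)
net_optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def adam_step(x, g, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update written with out-of-place ops so autograd records it."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    # sqrt has an infinite derivative at exactly 0; move eps inside the sqrt
    # if the gradient can be exactly zero
    return x - lr * m_hat / (v_hat.sqrt() + eps), m, v


def unrolled_restoration(n_steps=5):
    x = init_img.clone().requires_grad_(True)
    m = torch.zeros_like(x)
    v = torch.zeros_like(x)
    for t in range(1, n_steps + 1):
        inner_loss = ((x - prior) ** 2).sum()  # l2_loss(img_ano, prior)
        # create_graph=True keeps the gradient differentiable across unrolled steps
        (g,) = torch.autograd.grad(inner_loss, x, create_graph=True)
        g = g + net(x)                         # network-guided gradient
        x, m, v = adam_step(x, g, m, v, t)
    return x


for _ in range(3):
    net_optimizer.zero_grad()
    restored = unrolled_restoration()
    loss = F.mse_loss(restored, target)  # MSE stand-in for diceloss
    loss.backward()                      # backpropagates through all inner steps
    net_optimizer.step()
```

One design consequence worth knowing: at t = 1, m̂/√v̂ = g/(|g| + ε), which is roughly sign(g). The first Adam step therefore passes almost no gradient back to the network, and its influence enters mainly through the later steps.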

Thanks for the help!

I think you can pass a grad_callback to higher’s optimizer step if you want to modify the gradients before the update.
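As far as I understand it, the idea behind a gradient callback is a function that takes the list of gradients and returns a modified list before the update is applied. The sketch below is not higher's code. It is a hypothetical pure-PyTorch helper, `differentiable_sgd_step`, with a made-up signature that mimics the idea on top of a differentiable SGD step. The usage line clips the gradient, but the callback could equally add a network output.

```python
import torch
from typing import Callable, List, Optional


def differentiable_sgd_step(
    params: List[torch.Tensor],
    loss: torch.Tensor,
    lr: float,
    grad_callback: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None,
) -> List[torch.Tensor]:
    """Hypothetical helper: returns updated params; the update stays in the autograd graph."""
    grads = list(torch.autograd.grad(loss, params, create_graph=True))
    if grad_callback is not None:
        grads = grad_callback(grads)  # e.g. add a network output, or clip
    return [p - lr * g for p, g in zip(params, grads)]


x = torch.tensor([3.0], requires_grad=True)
(x_new,) = differentiable_sgd_step(
    [x], (x ** 2).sum(), lr=0.5,
    grad_callback=lambda gs: [g.clamp(-1.0, 1.0) for g in gs],  # gradient 6.0 clipped to 1.0
)
print(x_new.item())  # 2.5
```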


Oh, I missed that argument. Thanks a lot! Hopefully this will do the job.