Turning PyTorch Model Weights into non-leaf nodes

How can I modify model weights so that the update stays on the PyTorch computation graph (for instance, when second-order gradients are needed)?

By default, model weights are leaf nodes. Modifications either happen in place with gradients turned off, or happen on a copy of the parameters that cannot be assigned back to the model.
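A minimal sketch of the leaf/non-leaf behaviour described above, assuming a toy nn.Linear model that is not part of the original post. A registered parameter is a leaf. An out-of-place update built with create_graph=True is a non-leaf tensor that stays on the graph, but nn.Module refuses to store it under the parameter's name. An in-place update under torch.no_grad() keeps the parameter a leaf with no history.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 1)          # toy model, purely illustrative
print(lin.weight.is_leaf)      # True: registered parameters are leaf tensors

x = torch.randn(4, 3)
loss = lin(x).pow(2).mean()
(g,) = torch.autograd.grad(loss, lin.weight, create_graph=True)

# Out-of-place SGD step: the result is a non-leaf tensor that stays on the graph.
new_weight = lin.weight - 0.1 * g
print(new_weight.is_leaf, new_weight.grad_fn is not None)  # False True

# nn.Module only accepts an nn.Parameter (or None) for a registered parameter name.
try:
    lin.weight = new_weight
    assignment_failed = False
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# The usual in-place update under no_grad keeps the parameter a leaf with no history,
# so any dependence of the new weights on the data is lost.
with torch.no_grad():
    lin.weight -= 0.1 * g
print(lin.weight.is_leaf, lin.weight.grad_fn)  # True None
```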

Taking second derivatives via torch.autograd will not quite work for me, since I actually need to apply the model update.

To state the problem more concretely, my objective is as follows:

  1. x → model weights (a loss that depends on some variable x changes the model weights)
  2. new model weights → auxiliary loss (l2) (the model outputs on some static input are now computed with the new weights)

I need to find d(l2)/d(x). Basically, I need to simulate the actual change in model weights to get the new loss l2.
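One way to make the objective concrete, sketched with plain tensors instead of an nn.Module. The toy linear model, shapes, squared-error losses and learning rate are all assumptions for illustration. The inner gradient is computed with create_graph=True and the SGD step is applied out of place, so the auxiliary loss l2 stays differentiable with respect to x.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear model out = inp @ w, with w a leaf tensor.
w = torch.randn(3, 1, requires_grad=True)
x = torch.randn(5, 3, requires_grad=True)   # the variable we differentiate w.r.t.
static_inp = torch.randn(8, 3)               # fixed input for the auxiliary loss
static_target = torch.randn(8, 1)
lr = 0.1

# Step 1: x -> new weights. create_graph=True keeps the gradient, and hence
# the update, differentiable.
inner_loss = (x @ w).pow(2).mean()
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_new = w - lr * g                           # non-leaf tensor, a function of x

# Step 2: new weights -> auxiliary loss l2 on the static input.
l2 = ((static_inp @ w_new) - static_target).pow(2).mean()

# d(l2)/d(x) flows through w_new -> g -> inner_loss -> x (a second-order term).
(dl2_dx,) = torch.autograd.grad(l2, x)
print(dl2_dx.shape)  # torch.Size([5, 3])
```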


Since weights are initialized as leaf nodes, I don't think even using torch.autograd would help in this case. If your computation graph were small enough that you could update each parameter separately and do a manual forward pass without PyTorch modules, this could be done. But torch.autograd.grad needs a scalar output (or an explicit grad_outputs, in which case it returns a vector-Jacobian product rather than the full Jacobian), so calculating dw/dx directly would not be possible.
@ptrblck, @albanD, any thoughts?
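For reference, torch.autograd.grad does accept a non-scalar output when grad_outputs is supplied; it then returns the vector-Jacobian product v^T J instead of the Jacobian J. Since d(l2)/d(x) is exactly such a product (d(l2)/dw times dw/dx), autograd never has to materialize dw/dx. A small check, assuming a toy function f that stands in for the x → weights map:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)

def f(x):
    # Toy non-scalar map R^3 -> R^2, standing in for "x -> model weights".
    A = torch.arange(6.0).reshape(2, 3)
    return torch.tanh(A @ x)

v = torch.tensor([1.0, -2.0])   # plays the role of d(l2)/d(weights)

# Without grad_outputs, a non-scalar output is rejected.
try:
    torch.autograd.grad(f(x), x)
    scalar_required = False
except RuntimeError:
    scalar_required = True

# With grad_outputs=v we get the vector-Jacobian product v^T J.
(vjp,) = torch.autograd.grad(f(x), x, grad_outputs=v)

# The explicit 2x3 Jacobian, built only for comparison.
J = jacobian(f, x)
print(torch.allclose(vjp, v @ J))  # True
```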

Hi,

This is indeed a bit challenging today.
A work-in-progress solution for such use cases is a stateless API to which you pass weights directly; those weights are used instead of the ones stored in your model.
The function is torch.nn.utils._stateless.functional_call(model, params_and_buffers, args, kwargs) (docstring here; it is not documented on the website yet, but will be in the next version).

An example of how it is used: Forward-mode Automatic Differentiation (Beta) - PyTorch Tutorials 1.11.0+cu102 documentation
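A minimal sketch of the functional_call approach applied to the two-step objective from the question. Assumptions: a small toy MLP, made-up shapes, a single plain SGD inner step, and the public torch.func.functional_call (where this API lives in PyTorch 2.x), with torch.nn.utils.stateless.functional_call as a fallback for older releases. The private torch.nn.utils._stateless module named above may not exist in current versions.

```python
import torch
import torch.nn as nn

try:
    from torch.func import functional_call                # PyTorch >= 2.0
except ImportError:
    from torch.nn.utils.stateless import functional_call  # older releases (assumption)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))  # toy model

x = torch.randn(6, 4, requires_grad=True)   # variable we want d(l2)/d(x) for
static_inp = torch.randn(10, 4)
static_y = torch.randint(0, 3, (10,))
lr = 0.1

params = dict(model.named_parameters())

# Step 1: inner loss on x (prediction entropy, as in the toy code in this thread).
out = functional_call(model, params, (x,))
inner_loss = -(out.softmax(1) * out.log_softmax(1)).sum(1).mean()
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)

# Out-of-place SGD step: the new weights are non-leaf tensors that depend on x.
new_params = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Step 2: auxiliary loss with the updated weights; the module's own
# parameters are never overwritten.
logits = functional_call(model, new_params, (static_inp,))
l2 = nn.functional.cross_entropy(logits, static_y)

(dl2_dx,) = torch.autograd.grad(l2, x)
print(dl2_dx.shape)  # torch.Size([6, 4])
```

Because the updated weights live in a plain dict rather than in the module, nothing has to be forced back into the model, which sidesteps the leaf-parameter restriction.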


Hi @albanD
Thanks a lot for your input. We will go through the functional_call approach.
Meanwhile, we came across a package called "higher" from FAIR (GitHub - facebookresearch/higher: higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.), which seems to allow calculating second-order gradients in PyTorch.
I am sharing a toy implementation for your perusal. Could you please look at it and share your views?

import copy

import higher
import torch
from torchvision import models


def _objective(model, poison_x, x, y):
    # Work on a copy so the original model is never modified
    adv_model = copy.deepcopy(model)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, adv_model.parameters()), lr=0.1)
    # Wrap the model into a meta-object that allows for meta-learning steps via monkeypatching:
    with higher.innerloop_ctx(adv_model, optimizer, copy_initial_weights=False) as (fmodel, fopt):
        for _ in range(1):
            # The first step: x --> model weights
            outputs = fmodel(torch.cat([x, poison_x]))
            # Loss 1 (entropy here), which depends on the model's predictions for the poisoned input
            poison_loss = -(outputs.softmax(1) * outputs.log_softmax(1)).sum(1).mean(0)
            # Differentiable update of the model weights to minimize the prediction entropy
            fopt.step(poison_loss)

        # Auxiliary loss (l2) computed with the updated (non-leaf) model parameters
        new_preds = fmodel(x)
        new_loss = torch.nn.functional.cross_entropy(new_preds, y)
    return new_loss


def get_poison_delta(model, inputs, y):
    # Initialize the poisoned examples to zeros for simplicity
    # (poison_delta represents the poisoned input that will be added to the dataset)
    poison_delta = torch.zeros_like(inputs, requires_grad=True)
    poison_optimizer = torch.optim.SGD([poison_delta], lr=0.1, momentum=0.9, weight_decay=0)
    num_poison_iters = 10
    for _ in range(num_poison_iters):
        poison_optimizer.zero_grad()
        # Forward: one differentiable inner update, then the auxiliary loss l2
        loss = _objective(model, poison_delta, inputs, y)
        # d(l2)/d(poison_delta)
        # *****************THE SECOND ORDER GRADIENTS******************
        (poison_delta.grad,) = torch.autograd.grad(loss, poison_delta)
        # Optim step
        poison_optimizer.step()
        # Projection step (omitted for simplicity)
    return poison_delta.detach()


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = models.resnet18().to(device)
    print("Starting")
    batch_size, dims, pix = 100, 3, 32
    # Random inputs, just a minimal example to test gradient flow
    inputs = torch.rand(batch_size, dims, pix, pix, device=device)
    y = torch.randint(0, 10, (batch_size,), device=device)

    poisoned_image = get_poison_delta(model, inputs, y)
    # print(poisoned_image)


if __name__ == "__main__":
    main()

You can run this code from the command line with: python file_name.py

It seems that using the higher package for the forward pass and the optimizer step keeps the updated model weights as non-leaf nodes (the update stays on the graph), which then allows us to use torch.autograd.grad() to compute second-order gradients.
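Whichever tool keeps the update on the graph (higher, functional_call, or manual tensors), torch.autograd.gradcheck offers a cheap way to confirm that the resulting d(l2)/d(x) matches finite differences. A sketch with a toy linear model and one inner SGD step (an assumption, not the ResNet setup from the post), in double precision as gradcheck expects:

```python
import torch

torch.manual_seed(0)
# Double precision keeps the finite-difference estimates in gradcheck reliable.
w = torch.randn(3, 1, dtype=torch.double, requires_grad=True)
static_inp = torch.randn(8, 3, dtype=torch.double)
static_target = torch.randn(8, 1, dtype=torch.double)

def l2_of_x(x, lr=0.1):
    # Inner step: x -> updated weights, kept on the graph via create_graph=True.
    inner_loss = (x @ w).pow(2).mean()
    (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
    w_new = w - lr * g
    # Auxiliary loss evaluated with the updated weights.
    return ((static_inp @ w_new) - static_target).pow(2).mean()

x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
# gradcheck compares the autograd gradient w.r.t. x against finite differences
# and raises if they disagree.
ok = torch.autograd.gradcheck(l2_of_x, (x,))
print(ok)  # True
```

If the inner step were done in place under torch.no_grad(), the analytical gradient would miss the path through the update and this check would fail.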

Thanks
Sachin and Pratyush
