Problem with Autograd

Hello,
the problem I’m facing is quite complex and I’m not sure if I can reduce it to a minimal working example without cutting out some detail that might be the source of my problem, so I’ll try to be as clear as possible.

The problem I’m working on is the re-implementation of the paper “Dataset Distillation”. The overall objective is to modify a small subset of the training set so that training a network on this subset will lead to a good accuracy (potentially comparable to the one obtained by training the model on the whole dataset).

The algorithm consists of doing one or more SGD steps on the model using the small subset, then evaluating the model on a batch sampled from the training set and propagating the gradient back to the subset, modifying its input samples. To implement this I had to use the library higher, since I had to deal with second-order derivatives.

import torch
import higher

# `config` is a global dict defined elsewhere ('outer_steps', 'inner_steps', ...)
def distill(model, buff_data, img_lr, criterion, train_loader):
    model.train()

    buff_imgs, buff_trgs = buff_data

    model_lr = 0.1

    # buff_opt updates the distilled images, model_opt is unrolled by higher
    buff_opt = torch.optim.SGD([buff_imgs], lr=img_lr)
    model_opt = torch.optim.SGD(model.parameters(), lr=model_lr)

    for i in range(config['outer_steps']):
        for step, (ds_imgs, ds_trgs) in enumerate(train_loader):
            ds_imgs = ds_imgs.cuda()
            ds_trgs = ds_trgs.cuda()

            with higher.innerloop_ctx(model, model_opt) as (fmodel, diffopt):
                acc_loss = None
                for j in range(config['inner_steps']):
                    # First step modifies the model
                    buff_out = fmodel(buff_imgs)
                    buff_loss = criterion(buff_out, buff_trgs)
                    diffopt.step(buff_loss)

                    # Second step: evaluate on a real batch; this loss later updates the images
                    ds_out = fmodel(ds_imgs)
                    ds_loss = criterion(ds_out, ds_trgs)
                    acc_loss = acc_loss + ds_loss if acc_loss is not None else ds_loss

                # Backpropagate through all unrolled inner steps to the images
                acc_loss.backward()
                buff_opt.step()
                buff_opt.zero_grad()

It worked fine, but I then decided to also learn the learning rates. It works like this: if I’m doing 10 inner SGD steps, I keep a list of ten single-element tensors, one learning rate per step.
Each learning rate is specific to one SGD step and is learned by backpropagating the training loss computed right after that step.
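
To make the idea concrete, here is a minimal sketch of one hand-written differentiable SGD step with its own learnable learning rate. It does not use higher; the scalar weight w0, the quadratic losses, and the value 0.1 are illustrative assumptions, not taken from the paper or from the code in this thread.

```python
import torch

# Toy setup (all values are illustrative assumptions).
w0 = torch.tensor(1.0, requires_grad=True)   # model "weight"
lr = torch.tensor(0.1, requires_grad=True)   # learnable step size

# Inner step: gradient of the training loss w.r.t. the weight.
train_loss = w0 ** 2
(g,) = torch.autograd.grad(train_loss, w0, create_graph=True)

# Differentiable SGD update: w1 depends on both w0 and lr.
w1 = w0 - lr * g

# Evaluate the updated weight and backpropagate to the learning rate.
val_loss = w1 ** 2
val_loss.backward()

# Analytically: d(val_loss)/d(lr) = 2 * w1 * (-g) = 2 * 0.8 * (-2) = -3.2
print(lr.grad)
```

Because the update is written as an ordinary tensor expression, the learning rate receives a gradient like any other leaf tensor.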

def distill(model, buff_data, meta_lr, criterion, train_loader):
    model.train()

    buff_imgs, buff_trgs = buff_data

    model_lr = 0.1

    buff_opt = torch.optim.SGD([buff_imgs], lr=meta_lr)
    # DIFF
    model_opt = torch.optim.SGD(model.parameters(), lr=1)
    lr_list = []
    lr_opts = []
    for _ in range(config['inner_steps']):
        lr = torch.tensor([config['model_lr']], requires_grad=True, device='cuda')
        lr_list.append(lr)
        lr_opts.append(torch.optim.SGD([lr], meta_lr,))
    # END DIFF

    for i in range(config['outer_steps']):
        for step, (ds_imgs, ds_trgs) in enumerate(train_loader):
            ds_imgs = ds_imgs.cuda()
            ds_trgs = ds_trgs.cuda()
 
            with higher.innerloop_ctx(model, model_opt) as (fmodel, diffopt):
                acc_loss = None
                for j in range(config['inner_steps']):
                    # First step modifies the model
                    buff_out = fmodel(buff_imgs)
                    buff_loss = criterion(buff_out, buff_trgs)
                    diffopt.step(buff_loss)

                    # Second step modifies the images
                    ds_out = fmodel(ds_imgs)
                    ds_loss = criterion(ds_out, ds_trgs)
                    acc_loss = acc_loss + ds_loss if acc_loss is not None else ds_loss

                    # DIFF
                    lr_opts[j].zero_grad()
                    ds_loss.backward(retain_graph=True)
                    lr_opts[j].step()
                    # END DIFF

                acc_loss.backward()
                buff_opt.step()
                buff_opt.zero_grad()

If I don’t use retain_graph=True when calling backward to update the learning rate, I get an error on the second iteration. If I set it to True, the code runs but it’s extremely slow.

Since the update to a specific learning rate is computed using local information I can’t understand why I have to set retain_graph=True. I’m surely doing something wrong but I don’t understand what. Could someone give me some help?

Is that on the ds_loss backward or acc_loss backward? Because the acc_loss will backprop in the same graph as the ds_loss so I think the retain_graph is required here.
For the slowdown, I guess it depends how many times you unroll your loop as you differentiate through all these iterations.
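
The retain_graph requirement can be reproduced in isolation. The sketch below uses made-up tensors (not the model from this thread): two losses share an intermediate result, so the first backward() frees buffers that the second one still needs.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
shared = (w ** 2).sum()      # sub-graph used by both losses

loss_a = shared * 2.0
loss_b = shared * 3.0

loss_a.backward()            # frees the saved tensors of the shared sub-graph

second_backward_failed = False
try:
    loss_b.backward()        # needs the buffers that were just freed
except RuntimeError:
    second_backward_failed = True

print(second_backward_failed)
```

Passing retain_graph=True to the first call avoids the error, at the cost of keeping those buffers in memory.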

It’s on the ds_loss backward (you can see it in the second snippet, inside the inner loop, after the # DIFF comment).

“Because the acc_loss will backprop in the same graph as the ds_loss so I think the retain_graph is required here.”

I thought about it, and I had also tried changing ds_loss.backward(retain_graph=True) to criterion(fmodel(ds_imgs), ds_trgs).backward() to create a separate path in the graph that does not interfere with the later backpropagation on acc_loss. Unfortunately, that didn’t help much, as the problem occurs earlier, when I call ds_loss.backward() for the second time. Since every iteration optimizes a different learning rate, I don’t understand why I should retain the graph there.

“For the slowdown, I guess it depends how many times you unroll your loop as you differentiate through all these iterations.”

That’s fair, but for the same number of inner iterations the second snippet takes a lot longer. I don’t think it should, since I’m learning literally just one new parameter per step and it shouldn’t have to backprop through several steps (just through one, since I update it right after I evaluate the model on the training set).

I think that, whatever the bug is, the two “symptoms” are correlated, i.e. it’s taking so long because it’s doing something that requires retain_graph=True, while what I’m actually trying to do shouldn’t need it and should be a much simpler (and cheaper) update.

That sounds good.

Could you try to use torchviz to print the graph and see where it overlaps (give it as input the loss from the first and second iteration and see where the two graphs start to overlap)? It is indeed a bit hard to tell just from looking at the code.
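
If torchviz is awkward to install, the overlap can also be checked by hand, since each tensor's grad_fn exposes its parents through next_functions. The following is a rough sketch on a toy two-step unroll; the scalar weight, the losses, and the helper name graph_nodes are assumptions for illustration, and the update rule is hand-written rather than produced by higher.

```python
import torch

def graph_nodes(t):
    """Collect every autograd node reachable from tensor t."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(parent for parent, _ in fn.next_functions)
    return seen

w0 = torch.tensor(1.0, requires_grad=True)
lrs = [torch.tensor(0.1, requires_grad=True) for _ in range(2)]

w, losses = w0, []
for lr in lrs:
    (g,) = torch.autograd.grad(w ** 2, w, create_graph=True)
    w = w - lr * g                   # differentiable update
    losses.append((w - 0.5) ** 2)    # loss evaluated after this step

nodes_1 = graph_nodes(losses[0])
nodes_2 = graph_nodes(losses[1])
shared = nodes_1 & nodes_2           # nodes both losses backpropagate through

# Leaf tensors show up as AccumulateGrad nodes with a .variable attribute.
leaves_2 = [n.variable for n in nodes_2 if hasattr(n, "variable")]
reaches_first_lr = any(v is lrs[0] for v in leaves_2)
print(len(shared) > 0, reaches_first_lr)
```

In this toy case the second loss reaches the first step's learning rate, which is the kind of overlap to look for in the real graph.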

Sure, if it’s hard for me, and I wrote it, it’s even harder for you since you are not familiar with it.
I didn’t know about torchviz, and I had some problems using it on Windows, but here is the first graph:

[Image: torchviz graph of the first inner iteration]

I couldn’t upload the second one because it’s too big, but it is basically the second iteration with the first one stacked above it (since backpropagating on acc_loss needs all the previous iterations).

Now, I don’t know if you can download it in full resolution but on the center-left of the image, you can see a one-element tensor i.e. the learning rate for that step.

In the graph at the second iteration, there are both the learning rate for the second step and the one for the previous step. Basically, what I think is happening is that, at a given inner iteration, I’m not only updating the current learning rate but also all of the past ones.

Is there any way I can avoid this?

The image you shared is full resolution if you click on it. Thanks!

Could it be that the j index is a typo in your code? :slight_smile:

No, sorry I just mistyped it in the snippet (I edited now); j is the index of the inner loop so it makes sense.

I think that, since fmodel is functional and the graph of all inner iterations is kept in order to backpropagate to buff_imgs, when I call backward on ds_loss to update the current learning rate, the gradient actually flows back through the whole graph.

What could I do to prevent this? I need the whole graph to update buff_imgs but when I update the learning rate I just need the gradient to flow to the current learning rate tensor.

I think one solution could be to calculate the output of fmodel two times: one to update buff_imgs after all iterations and the other with no_grad() to update the learning rate. Would that work?

If you use a nightly build, you can pass inputs= to .backward() to specify what you want to compute gradients for. In this case: ds_loss.backward(inputs=(lr_list[j],), retain_graph=True).

Otherwise, you can do grad, = autograd.grad(ds_loss, lr_list[j], retain_graph=True); lr_list[j].grad = grad but it is not as nice :smiley:

Thank you so much, that’s just what I needed!
Yeah, the nightly build version is cooler but I’ll stick with the standard one for now.

Side question: do you plan to add graph visualization tools directly to PyTorch?

torchviz is more or less the official way to do it :stuck_out_tongue:
The main drawback is that it is very hard to read from a user point of view in general.
Note that the tensorboard implementation has a nice tool to visualize graphs at the torch.nn level as well :slight_smile:
