Hello,
the problem I’m facing is quite complex, and I’m not sure I can reduce it to a minimal working example without cutting out some detail that might be the source of the problem, so I’ll try to be as clear as possible.
I’m re-implementing the paper “Dataset Distillation”. The overall objective is to modify a small subset of the training set so that training a network on this subset alone yields good accuracy (potentially comparable to that obtained by training the model on the whole dataset).
The algorithm performs one or more SGD steps on the model using the small subset, then evaluates the updated model on a batch sampled from the training set and backpropagates that loss through the SGD steps to the subset, modifying its input samples. To implement this I used the library higher, since I had to deal with second-order derivatives.
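As a point of reference, the bilevel step described above can be sketched in plain PyTorch without higher, using `torch.autograd.grad(..., create_graph=True)` so the inner SGD update stays differentiable. This is a minimal sketch under assumed conditions: a hypothetical linear softmax classifier, random tensors standing in for the synthetic subset and a real batch, and the names `syn_x`, `real_x`, `w0` and `distill_step` are all illustrative, not taken from the post.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: 10 synthetic samples, 20 features, 3 classes.
n_syn, n_feat, n_cls = 10, 20, 3
syn_x = torch.randn(n_syn, n_feat, requires_grad=True)  # learnable synthetic inputs
syn_y = torch.arange(n_syn) % n_cls                      # fixed synthetic labels
real_x = torch.randn(64, n_feat)                         # stand-in for a real training batch
real_y = torch.randint(0, n_cls, (64,))

w0 = torch.randn(n_feat, n_cls) * 0.01                   # initial weights of a linear model
model_lr = 0.1
syn_opt = torch.optim.SGD([syn_x], lr=0.01)

def distill_step():
    w = w0.clone().requires_grad_(True)
    # Inner step: one SGD update of the model on the synthetic data.
    inner_loss = F.cross_entropy(syn_x @ w, syn_y)
    # create_graph=True keeps the gradient itself differentiable (second-order terms).
    (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
    w_new = w - model_lr * g
    # Outer step: evaluate the updated weights on real data and backprop into syn_x.
    outer_loss = F.cross_entropy(real_x @ w_new, real_y)
    syn_opt.zero_grad()
    outer_loss.backward()
    syn_opt.step()
    return outer_loss.item()
```

Calling `distill_step()` repeatedly (e.g. `for _ in range(100): distill_step()`) updates only `syn_x`; the model always restarts from `w0`, which mirrors the idea of a fixed initialization in the post.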
```python
import torch
import higher

# `config` is a dict of hyperparameters defined elsewhere in the script.
def distill(model, buff_data, img_lr, criterion, train_loader):
    model.train()
    buff_imgs, buff_trgs = buff_data
    model_lr = 0.1
    buff_opt = torch.optim.SGD([buff_imgs], lr=img_lr)
    model_opt = torch.optim.SGD(model.parameters(), lr=model_lr)
    for i in range(config['outer_steps']):
        for step, (ds_imgs, ds_trgs) in enumerate(train_loader):
            ds_imgs = ds_imgs.cuda()
            ds_trgs = ds_trgs.cuda()
            with higher.innerloop_ctx(model, model_opt) as (fmodel, diffopt):
                acc_loss = None
                for j in range(config['inner_steps']):
                    # First step modifies the model
                    buff_out = fmodel(buff_imgs)
                    buff_loss = criterion(buff_out, buff_trgs)
                    diffopt.step(buff_loss)
                    # Second step modifies the images
                    ds_out = fmodel(ds_imgs)
                    ds_loss = criterion(ds_out, ds_trgs)
                    acc_loss = acc_loss + ds_loss if acc_loss is not None else ds_loss
                acc_loss.backward()
                buff_opt.step()
                buff_opt.zero_grad()
```
It worked fine, but then I decided to also learn the learning rates. It works like this: if I’m doing 10 inner SGD steps, I have a list of ten single-element tensors, one per step.
Each learning rate is specific to one SGD step and is learned by backpropagating the training loss computed at that step.
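For comparison, here is a minimal sketch of per-step learnable learning rates in plain PyTorch (not higher), under the same assumed toy setup as a linear classifier on random data. All names (`lrs`, `meta_opt`, `unrolled_step`) are hypothetical. The key point it illustrates is that a learning rate only receives a gradient if it actually appears in the update expression `w - lrs[j] * g`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: linear classifier, random data.
n_syn, n_feat, n_cls, inner_steps = 10, 20, 3, 4
syn_x = torch.randn(n_syn, n_feat, requires_grad=True)
syn_y = torch.arange(n_syn) % n_cls
real_x = torch.randn(64, n_feat)
real_y = torch.randint(0, n_cls, (64,))
w0 = torch.randn(n_feat, n_cls) * 0.01

# One learnable scalar learning rate per inner step.
lrs = [torch.tensor(0.1, requires_grad=True) for _ in range(inner_steps)]
meta_opt = torch.optim.SGD([syn_x, *lrs], lr=0.01)

def unrolled_step():
    w = w0.clone().requires_grad_(True)
    acc_loss = 0.0
    for j in range(inner_steps):
        inner_loss = F.cross_entropy(syn_x @ w, syn_y)
        (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
        w = w - lrs[j] * g  # lrs[j] is in the graph, so it receives a gradient
        acc_loss = acc_loss + F.cross_entropy(real_x @ w, real_y)
    meta_opt.zero_grad()
    acc_loss.backward()  # a single backward pass, so no retain_graph is needed
    meta_opt.step()
    return acc_loss.item()
```

Note the design choice: here every learning rate is updated from the gradient of the summed loss, so `lrs[j]` also receives contributions from the losses of later steps. This differs from the per-step scheme described above.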
```python
def distill(model, buff_data, meta_lr, criterion, train_loader):
    model.train()
    buff_imgs, buff_trgs = buff_data
    model_lr = 0.1
    buff_opt = torch.optim.SGD([buff_imgs], lr=meta_lr)
    # DIFF
    model_opt = torch.optim.SGD(model.parameters(), lr=1)
    lr_list = []
    lr_opts = []
    for _ in range(config['inner_steps']):
        lr = torch.tensor([config['model_lr']], requires_grad=True, device='cuda')
        lr_list.append(lr)
        lr_opts.append(torch.optim.SGD([lr], meta_lr,))
    # END DIFF
    for i in range(config['outer_steps']):
        for step, (ds_imgs, ds_trgs) in enumerate(train_loader):
            ds_imgs = ds_imgs.cuda()
            ds_trgs = ds_trgs.cuda()
            with higher.innerloop_ctx(model, model_opt) as (fmodel, diffopt):
                acc_loss = None
                for j in range(config['inner_steps']):
                    # First step modifies the model
                    buff_out = fmodel(buff_imgs)
                    buff_loss = criterion(buff_out, buff_trgs)
                    diffopt.step(buff_loss)
                    # Second step modifies the images
                    ds_out = fmodel(ds_imgs)
                    ds_loss = criterion(ds_out, ds_trgs)
                    acc_loss = acc_loss + ds_loss if acc_loss is not None else ds_loss
                    # DIFF
                    lr_opts[j].zero_grad()
                    ds_loss.backward(retain_graph=True)
                    lr_opts[j].step()
                    # END DIFF
                acc_loss.backward()
                buff_opt.step()
                buff_opt.zero_grad()
```
If I don’t pass `retain_graph=True` when calling `backward` to update the learning rate, I get an error on the second iteration. If I set it to `True`, it works, but it’s extremely slow.
Since the update to a specific learning rate is computed using local information, I can’t understand why I have to set `retain_graph=True`. I’m surely doing something wrong, but I don’t understand what. Could someone give me some help?
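A tiny scalar sketch may help show where the non-locality comes from. It uses plain PyTorch rather than higher, and every name in it is hypothetical: `syn` plays the role of `buff_imgs`, `lrs` the per-step learning rates, and `losses[j]` the per-step `ds_loss`. Because the weight after step `j` is built from every earlier update, the loss at step `j` shares graph nodes with all earlier losses, so a `backward()` on one of them frees buffers that the next one still needs.

```python
import torch

def unroll(n_steps=3):
    """Hypothetical scalar stand-in for the inner loop: one weight w trained
    for n_steps on a 'synthetic' value syn, with one learning rate per step."""
    syn = torch.tensor(1.0, requires_grad=True)  # plays the role of buff_imgs
    lrs = [torch.tensor(0.1, requires_grad=True) for _ in range(n_steps)]
    w = torch.tensor(0.0)
    losses = []
    for lr in lrs:
        g = 2 * (w * syn - 1.0) * syn  # d/dw of (w*syn - 1)^2, written by hand
        w = w - lr * g                 # depends on every earlier update
        losses.append((w - 2.0) ** 2)  # plays the role of ds_loss at step j
    return syn, lrs, losses

# 1) Per-step backward without retain_graph: the second call fails, because
#    losses[1] reaches back into nodes whose saved tensors were already freed.
syn, lrs, losses = unroll()
losses[0].backward()
try:
    losses[1].backward()
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# 2) With retain_graph=True it runs, but the step-1 loss also writes gradient
#    into the step-0 learning rate and into syn, not only into lrs[1].
syn2, lrs2, losses2 = unroll()
losses2[1].backward(retain_graph=True)
print(lrs2[0].grad is not None, syn2.grad is not None)  # prints: True True

# 3) One backward over the summed loss (like acc_loss) needs no retain_graph.
syn3, lrs3, losses3 = unroll()
sum(losses3).backward()
```

If this toy model reflects the real code, two things may be worth checking: each `ds_loss.backward()` also accumulates gradient into `buff_imgs.grad` before `acc_loss.backward()` runs, and the tensors in `lr_list` only receive a gradient if they are actually used in the update performed by `diffopt.step`, which in the snippet above is created from `model_opt` with `lr=1`.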