I’m trying to run a dataset distillation algorithm (see paper here) and I’ve run into an implementation problem. In case you’re not familiar with this paper, I’ll briefly explain its main idea. Dataset distillation aims to synthesize a small amount of data such that a model trained on it can achieve performance comparable to a model trained on the original data. The synthesis is done by backpropagating through the training steps, which requires 2nd-order derivatives. Note that training on the synthesized data can be extended to multiple steps and epochs.
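As a sketch in my own notation ($K$ synthetic steps with data $\tilde{x}_i$, labels $\tilde{y}_i$ and learning rates $\eta_i$; real data $(x, y)$; $\ell$ the cross-entropy loss; $w_0$ the initial weights), the unrolled training and the reverse recursion that `backward` in the example below computes are:

$$
\begin{aligned}
g_i &= \eta_i \, \nabla_w \ell(\tilde{x}_i, \tilde{y}_i; w_i), \qquad w_{i+1} = w_i - g_i, \qquad i = 0, \dots, K-1, \\
L &= \ell(x, y; w_K), \qquad \delta_K = \nabla_w \ell(x, y; w_K), \\
\frac{\partial L}{\partial \tilde{x}_i} &= -\eta_i \left( \frac{\partial\, \nabla_w \ell(\tilde{x}_i, \tilde{y}_i; w_i)}{\partial \tilde{x}_i} \right)^{\!\top} \delta_{i+1}, \\
\frac{\partial L}{\partial \eta_i} &= -\nabla_w \ell(\tilde{x}_i, \tilde{y}_i; w_i)^{\top} \delta_{i+1}, \\
\delta_i &= \left( I - \eta_i \, \nabla_w^2 \ell(\tilde{x}_i, \tilde{y}_i; w_i) \right) \delta_{i+1}.
\end{aligned}
$$

Here `dw` corresponds to $\delta_{i+1}$, `dgw` to $-\delta_{i+1}$, and the three entries of `hvp_grad` to the $w$, $\tilde{x}$ and $\eta$ terms. The second-order parts are Hessian-vector products, so the full Hessian is never formed.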
Below is a minimal example of how the data is synthesized:
```python
import torch
import torch.nn.functional as F

# rdata and rlabel denote the real data and real labels
# steps contains the synthesized data, labels and their corresponding lr
def forward(model, rdata, rlabel, steps):
    model.train()
    w = model.get_param()
    params = [w]
    gws = []

    # first we train the model on the synthesized data
    for step_i, (data, label, lr) in enumerate(steps):
        with torch.enable_grad():
            output = model.forward_with_param(data, w)
            loss = F.cross_entropy(output, label)
        # lr is passed as grad_outputs, so gw is already scaled by lr
        # (lr must have the same shape as loss, i.e. a 0-dim tensor)
        gw, = torch.autograd.grad(loss, w, lr, create_graph=True)

        with torch.no_grad():
            new_w = w.sub(gw).requires_grad_()
            params.append(new_w)
            gws.append(gw)
            w = new_w

    # final L, evaluated on real data
    model.eval()
    output = model.forward_with_param(rdata, params[-1])
    ll = F.cross_entropy(output, rlabel)
    return (ll, params, gws)


def backward(model, rdata, rlabel, steps, saved_for_backward):
    l, params, gws = saved_for_backward
    gdatas = []
    glrs = []

    # d denotes derivative
    dw, = torch.autograd.grad(l, (params[-1],))

    # backward
    model.train()
    for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
        hvp_in = [w]
        hvp_in.append(data)
        hvp_in.append(lr)
        dgw = dw.neg()  # gw is already weighted by lr, so simple negation
        hvp_grad = torch.autograd.grad(
            outputs=(gw,),
            inputs=hvp_in,
            grad_outputs=(dgw,)
        )
        # Update for next iteration, i.e., previous step
        with torch.no_grad():
            # Save the computed gdata and glrs (collected in reverse step order)
            gdatas.append(hvp_grad[1])
            glrs.append(hvp_grad[2])
            # Update dw
            # dw becomes the gradient w.r.t. the weights of the previous step
            dw.add_(hvp_grad[0])
    return gdatas, glrs


saved = forward(model, rdata, rlabel, steps)
grad_infos = backward(model, rdata, rlabel, steps, saved)
```
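To sanity-check the reverse loop without PyTorch, here is a hypothetical scalar analogue of the same forward/backward. It assumes a one-parameter model `w * x` with a squared-error loss instead of cross-entropy, so all second derivatives can be written by hand, and compares the result against central finite differences. All function names are made up for this sketch; it is not the paper's implementation.

```python
# Hypothetical scalar analogue of forward()/backward():
# prediction = w * x, per-sample loss = 0.5 * (w * x - y) ** 2.

def dloss_dw(w, x, y):
    """Gradient of the per-sample loss w.r.t. w."""
    return (w * x - y) * x


def real_loss(w, real):
    """Mean loss of weight w on the real data (list of (x, y) pairs)."""
    return sum(0.5 * (w * x - y) ** 2 for x, y in real) / len(real)


def real_loss_grad(w, real):
    """dL/dw of the real-data loss."""
    return sum(dloss_dw(w, x, y) for x, y in real) / len(real)


def forward_scalar(w0, real, steps):
    """Unrolled training on synthetic (x, y, lr) steps, then loss on real data."""
    w, params = w0, [w0]
    for x, y, lr in steps:
        gw = lr * dloss_dw(w, x, y)  # already weighted by lr
        w = w - gw
        params.append(w)
    return real_loss(w, real), params


def backward_scalar(real, steps, params):
    """Reverse loop mirroring backward(): dL/dx_i and dL/dlr_i for each step."""
    dw = real_loss_grad(params[-1], real)  # dL/dw_K
    gdatas, glrs = [], []
    for (x, y, lr), w in reversed(list(zip(steps, params))):
        dgw = -dw                                  # because w_{i+1} = w_i - gw_i
        gdatas.append(lr * (2 * w * x - y) * dgw)  # (d gw_i / d x_i) * dgw
        glrs.append(dloss_dw(w, x, y) * dgw)       # (d gw_i / d lr_i) * dgw
        dw += lr * x * x * dgw                     # add (d gw_i / d w_i) * dgw
    return gdatas[::-1], glrs[::-1]                # back to step order


def finite_diff(real, steps, i, j, eps=1e-6):
    """Central difference of the final loss w.r.t. steps[i][j] (j=0: x, j=2: lr)."""
    def loss_with(delta):
        perturbed = [list(s) for s in steps]
        perturbed[i][j] += delta
        return forward_scalar(0.0, real, perturbed)[0]
    return (loss_with(eps) - loss_with(-eps)) / (2 * eps)


real = [(1.0, 2.0), (2.0, 3.5), (-1.0, -2.2)]
steps = [(0.5, 1.0, 0.1), (1.5, 2.0, 0.05), (-0.7, -1.0, 0.2)]
loss, params = forward_scalar(0.0, real, steps)
gdatas, glrs = backward_scalar(real, steps, params)
for i in range(len(steps)):
    print(i, gdatas[i], finite_diff(real, steps, i, 0),
          glrs[i], finite_diff(real, steps, i, 2))
```

The analytic and finite-difference columns should agree closely, which is a cheap way to check the sign conventions (`dgw = -dw`, then `dw += ...`) before debugging the full model.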
So here is my problem: when synthesizing data over multiple steps and multiple epochs, the buffers (e.g. the running mean/var in BN) should either be tracked and added to the computation graph, or stay constant. Otherwise the backprop computation will be incorrect, since the wrong buffer values will be used.
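For the "stay constant" option, one way to keep BN in train mode (normalizing with batch statistics) while stopping it from changing its running statistics is to temporarily set `momentum = 0.0`, because PyTorch's running-stat update is `running = (1 - momentum) * running + momentum * batch_stat`. This is a minimal sketch assuming standard `nn.BatchNorm1d/2d/3d` layers; `freeze_bn_stats` and `restore_bn_momentum` are hypothetical helper names. `num_batches_tracked` still increments.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn_stats(model: nn.Module) -> dict:
    """Set momentum=0.0 on every BN layer so running_mean/var stop changing.

    BN still normalizes with batch statistics in train mode.
    Returns the previous momenta so they can be restored.
    """
    old = {}
    for name, m in model.named_modules():
        if isinstance(m, BN_TYPES):
            old[name] = m.momentum
            m.momentum = 0.0
    return old


def restore_bn_momentum(model: nn.Module, old: dict) -> None:
    """Put back the momenta saved by freeze_bn_stats."""
    for name, m in model.named_modules():
        if name in old:
            m.momentum = old[name]


# Demo on a toy model (hypothetical, not the distillation network)
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
model.train()
bn = model[1]
mean_before = bn.running_mean.clone()

old = freeze_bn_stats(model)
model(torch.randn(8, 3, 10, 10))       # batch stats used, running stats untouched
mean_frozen = bn.running_mean.clone()

restore_bn_momentum(model, old)
model(torch.randn(8, 3, 10, 10))       # running stats are updated again
mean_after = bn.running_mean.clone()
```

If the running statistics are not needed at all, constructing the layers with `track_running_stats=False` is an alternative: BN then uses batch statistics in both train and eval mode.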
My question is: how do I add those buffers to the computation graph while keeping the BN layer behaving normally? Or should I restore the buffers of each step and epoch at the appropriate time?
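For the "restore" option, the buffers can be snapshotted with `named_buffers()` and copied back in place. A minimal sketch follows; `snapshot_buffers`, `load_buffers` and `preserved_buffers` are hypothetical helper names, and it assumes all relevant state lives in registered buffers (as it does for `nn.BatchNorm*`).

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


def snapshot_buffers(model: nn.Module) -> dict:
    """Detached copies of every buffer (BN running_mean/var, num_batches_tracked)."""
    return {name: b.detach().clone() for name, b in model.named_buffers()}


@torch.no_grad()
def load_buffers(model: nn.Module, saved: dict) -> None:
    """Copy saved values back into the model's buffers in place."""
    for name, b in model.named_buffers():
        b.copy_(saved[name])


@contextmanager
def preserved_buffers(model: nn.Module):
    """Snapshot buffers on entry and restore them on exit."""
    saved = snapshot_buffers(model)
    try:
        yield saved
    finally:
        load_buffers(model, saved)


# Demo on a toy model (hypothetical, not the distillation network)
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(5, 4), nn.BatchNorm1d(4))
model.train()
with preserved_buffers(model) as saved:
    model(torch.randn(16, 5))
    changed_inside = not torch.equal(model[1].running_mean, saved["1.running_mean"])
restored = torch.equal(model[1].running_mean, saved["1.running_mean"])
```

In the distillation loop, one placement is to wrap both `forward(model, rdata, rlabel, steps)` and `backward(...)` for an epoch in `with preserved_buffers(model):`, so every epoch starts from the same running statistics and no in-place copy happens while the graph built in `forward` is still needed.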
Thanks a lot!
Edit: in case this example is too complex to understand, the algorithm is presented as follows: