Hi!
I’m trying to run a dataset distillation algorithm (see the paper here) and I’ve run into an implementation problem. In case you’re not familiar with this paper, I’ll briefly explain its main idea. Dataset distillation aims to synthesize a small number of data points such that a model trained on them achieves performance comparable to a model trained on the original data. The synthesis procedure relies on second-order derivatives. Note that in practice, training on the synthesized data can be extended to multiple steps and epochs.

Below is a minimal example of how the data is synthesized:

```python
import torch
import torch.nn.functional as F

# rdata and rlabel denote the real data and real labels
# steps contains tuples of (synthesized data, label, corresponding lr)
def forward(model, rdata, rlabel, steps):
    model.train()
    w = model.get_param()
    params = [w]
    gws = []

    # first we train the model on synthesized data
    for step_i, (data, label, lr) in enumerate(steps):
        output = model.forward_with_param(data, w)
        loss = F.cross_entropy(output, label)

        # gradient step on w, kept in the graph for 2nd-order derivatives
        gw, = torch.autograd.grad(loss, w, lr.squeeze(), create_graph=True)
        with torch.no_grad():
            new_w = w.sub(gw).requires_grad_()
            params.append(new_w)
            gws.append(gw)
            w = new_w

    # final L, evaluated on real data
    model.eval()
    output = model.forward_with_param(rdata, params[-1])
    ll = F.cross_entropy(output, rlabel)
    return (ll, params, gws)


def backward(model, rdata, rlabel, steps, saved_for_backward):
    l, params, gws = saved_for_backward
    gdatas = []
    glrs = []
    # d denotes derivative
    dw, = torch.autograd.grad(l, (params[-1],))

    # backward through the training steps, last to first
    model.train()
    for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
        hvp_in = [w]
        hvp_in.append(data)
        hvp_in.append(lr)
        dgw = dw.neg()  # gw is already weighted by lr, so simple negation
        hvp_grad = torch.autograd.grad(
            outputs=(gw,),
            inputs=hvp_in,
            grad_outputs=(dgw,),
        )
        # Save the computed gdata and glr
        gdatas.append(hvp_grad[1])
        glrs.append(hvp_grad[2])

        # Update dw for the next iteration, i.e., the previous step:
        # dw becomes the gradient w.r.t. the w entering this step
        with torch.no_grad():
            dw.add_(hvp_grad[0])

    return gdatas, glrs


saved = forward(model, rdata, rlabel, steps)
grad_infos = backward(model, rdata, rlabel, steps, saved)
```

So here is my problem: when synthesizing data over multiple steps and multiple epochs, those buffers (e.g. the running mean/var in BN) should either be tracked and added to the computation graph, or stay constant. Otherwise the backprop computation will be incorrect, since the wrong buffer values will be used.
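To make the issue concrete, here is a small sketch (plain PyTorch, not from the paper’s code) showing that a BN layer’s buffers change during a train-mode forward, and that the change happens entirely outside the computation graph:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
before = bn.running_mean.clone()

# A batch whose mean is far from zero, so the running mean must move
x = torch.randn(8, 3) + 5.0
bn.train()
bn(x)

# The buffer was updated in-place, outside the computation graph
changed = not torch.equal(before, bn.running_mean)
in_graph = bn.running_mean.grad_fn is not None
```

Here `changed` ends up `True` while `in_graph` stays `False`: the update is invisible to autograd.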

My question is: how can I add those buffers to the computation graph while keeping the BN layer’s behavior normal? Or should I instead restore the buffers of each step and epoch at the appropriate time?

Thanks a lot!

Edit: in case this example is too complex to follow, see the algorithm as presented in the paper.

Hi,

I’m not sure what you mean by “how to add those buffers into the computation graph”.
Do you want them to have a fixed value and never be updated?

I mean that if the model is trained on the synthesized data for multiple steps / multiple epochs, and we just backprop as in this `forward` and `backward` example, then the gradient computation will be incorrect. After multiple steps / epochs, the buffer values (running mean/var) have been updated. Since the buffers are not part of the computation graph, the final buffer values are used during backprop instead of the values that were current at each step, especially when computing gradients for the early steps/epochs.

Also, I do want them to be updated, otherwise the BN layer is not fully functional.

Sorry, I am a bit confused by your naming. Could you define precisely what you mean by `updated buffer`, `real buffer`, and `the buffer value is not in the computation graph`?

Also, when you use batchnorm in train mode, `running_mean` and `running_var` are not used; they are only updated, to be used later in `.eval()` mode.
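A quick sketch illustrating this point (using a standard `nn.BatchNorm1d`): the train-mode output depends only on the batch statistics, so corrupting the running stats changes nothing:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3)

bn = nn.BatchNorm1d(3).train()
out1 = bn(x).detach()

# Corrupt the running statistics; a train-mode forward should not care
with torch.no_grad():
    bn.running_mean.fill_(100.0)
    bn.running_var.fill_(100.0)
out2 = bn(x).detach()

same_in_train = torch.allclose(out1, out2)
```

`same_in_train` comes out `True`: in train mode the running stats are write-only.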

I’m sorry, perhaps those names were not 100% accurate, and sorry for the confusion. Perhaps this formulation is easier to follow.

The problem is that the dataset distillation algorithm assumes the model weights (and running mean/var) stay unchanged during distillation (multiple runs of `forward` and `backward`). But after one run of `forward` and `backward`, the running mean/var have changed, leading to incorrect gradient computation in the next run of `forward` and `backward`. Also, the number of synthesized images is very small (10-100 images), so we cannot simply use the statistics of the synthesized images. The naive solution would be to always use batch norm in eval mode (`track_running_stats=False`), which would hurt the model’s capacity. A cleverer solution would be to track the buffers in the computation graph and then use gradients to update the running mean/var, just like the images are synthesized in the `backward` function.

And my question is: how can I add the running mean/var to the computation graph while keeping the BN layer’s behavior normal?
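As a side note, the “restore the buffers at the appropriate time” alternative mentioned earlier can be sketched like this (a hypothetical helper pattern, not from the paper’s code): snapshot all buffers before a run and copy them back afterwards, so every run starts from the same state:

```python
import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Snapshot every buffer (running_mean, running_var, num_batches_tracked)
saved = copy.deepcopy(dict(model.named_buffers()))

# ... one distillation run; the train-mode forward mutates the buffers ...
model.train()
model(torch.randn(8, 4) + 2.0)

# Restore, so the next run starts from the same state
with torch.no_grad():
    for name, buf in model.named_buffers():
        buf.copy_(saved[name])

restored = torch.equal(model[1].running_mean, saved['1.running_mean'])
```

This keeps the buffers constant across runs, but of course it does not make them learnable.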

So you want to change the batchnorm layer to behave just like a batchnorm in eval mode, with `running_mean` and `running_var` being parameters that you train jointly with the rest of the model?

That’s right!
During training, I want the BN layer to use the batch mean/var, while `running_mean` and `running_var` are trained like ordinary model parameters. During evaluation, I want the trained `running_mean` and `running_var` to be used just like in a regular BN layer.
Is there any way to do so?

You will need to subclass batchnorm to make that happen.
Here is an example for the 2d version:

```python
class MyLearntBatchnorm(nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        # Initialize the regular batchnorm
        super().__init__(*args, **kwargs)
        # Get the size of the running_* buffers
        stats_size = self.running_mean.size()
        # Unregister them as buffers by deleting them
        del self.running_mean
        del self.running_var
        # Make them learnable parameters now
        self.running_mean = nn.Parameter(torch.empty(stats_size))
        self.running_var = nn.Parameter(torch.empty(stats_size))
        # Initialize them whichever way you want (I don't think uniform
        # is a good idea here but not sure what to use)
        with torch.no_grad():
            self.running_mean.uniform_()
            self.running_var.uniform_()

    def forward(self, *args, **kwargs):
        # Our version behaves the same during train and eval and mimics
        # the original eval mode
        orig_training = self.training
        self.training = False
        # Use the regular batchnorm forward
        out = super().forward(*args, **kwargs)
        # Restore the original training flag
        self.training = orig_training
        return out
```

I haven’t tried the code so there might be some typos, but it should clearly present the idea.
What could block this approach is that the gradients for `running_mean` and `running_var` are not implemented in the provided batchnorm.
In that case, you will have to reimplement the batchnorm by hand using Python functions.


So in addition to the batchnorm in your example code, I would have to implement a `backward` method for batchnorm, to manually compute the gradients for `running_mean` and `running_var`?

You can use this new batchnorm as a drop-in replacement.
If it runs fine, then you’re done.

If it complains that it cannot compute gradients for `running_mean` and `running_var`, then you will have to change the `forward()` function to implement batchnorm yourself:

```python
# This is not the exact code you want, just a quick version to give you an example
def forward(self, input):
    # Normalize with the learnt stats (broadcast over the channel dim)
    mean = self.running_mean[None, :, None, None]
    var = self.running_var[None, :, None, None]
    out = (input - mean) / torch.sqrt(var + self.eps)
    # The affine transformation from batchnorm
    out = out * self.weight[None, :, None, None] + self.bias[None, :, None, None]
    return out
```

This will be a bit slower / consume a bit more memory than the fully optimized one, but it will do what you want.
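To sanity-check such a hand-written forward, one can compare it against a stock `BatchNorm2d` in eval mode; the `eps` term and the channel-dim broadcasting are the parts that are easy to get wrong:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3).eval()
# Give the stats non-trivial values so the comparison is meaningful
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 3, 4, 4)

# Manual eval-mode batchnorm: normalize, then the affine transformation
mean = bn.running_mean[None, :, None, None]
var = bn.running_var[None, :, None, None]
manual = (x - mean) / torch.sqrt(var + bn.eps)
manual = manual * bn.weight[None, :, None, None] + bn.bias[None, :, None, None]

matches = torch.allclose(manual, bn(x), atol=1e-6)
```

`matches` should come out `True`, confirming the manual formula reproduces the built-in eval-mode behavior.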

Ah, I see! Thank you, I think that solves my problem.
Also, I’m sorry for any inconvenience or confusion I’ve caused. I don’t know much about how to ask a question properly since I’m new here.
Thanks again!

No problem, happy to help! Glad that this works for you!