How to track and add autograd computation graphs for buffers

Hi!
I’m trying to run a dataset distillation algorithm (see the paper here) and I’ve run into an implementation problem. In case you’re not familiar with this paper, I will briefly explain its main idea. Dataset distillation aims to synthesize a small number of data points such that a model trained on them achieves performance close to that of a model trained on the original data. This synthesis is done via second-order derivatives. Note that, in practice, training on the synthesized data can be extended to multiple steps and multiple epochs.

Below is a minimal example of how to synthesize the data:

import torch
import torch.nn.functional as F

# rdata and rlabel denote the real data and real labels
# steps contains the synthesized data, labels and their corresponding lrs
def forward(model, rdata, rlabel, steps):
        model.train()
        w = model.get_param()
        params = [w]
        gws = []

        # first we train model on synthesized data
        for step_i, (data, label, lr) in enumerate(steps):
            with torch.enable_grad():
                output = model.forward_with_param(data, w)
                loss = F.cross_entropy(output, label)
            gw, = torch.autograd.grad(loss, w, lr, create_graph=True)

            with torch.no_grad():
                new_w = w.sub(gw).requires_grad_()
                params.append(new_w)
                gws.append(gw)
                w = new_w

        # final L, evaluated on real data
        model.eval()
        output = model.forward_with_param(rdata, params[-1])
        ll = F.cross_entropy(output, rlabel)
        return (ll, params, gws)

def backward(model, rdata, rlabel, steps, saved_for_backward):
        l, params, gws = saved_for_backward
        gdatas = []
        glrs = []
        # d denotes derivative
        dw, = torch.autograd.grad(l, (params[-1],))

        # backward
        model.train()
        for (data, label, lr), w, gw in reversed(list(zip(steps, params, gws))):
            hvp_in = [w]
            hvp_in.append(data)
            hvp_in.append(lr)
            dgw = dw.neg()  # gw is already weighted by lr, so simple negation
            hvp_grad = torch.autograd.grad(
                outputs=(gw,),
                inputs=hvp_in,
                grad_outputs=(dgw,)
            )
            # Update for next iteration, i.e., previous step
            with torch.no_grad():
                # Save the computed gdata and glrs
                gdatas.append(hvp_grad[1])
                glrs.append(hvp_grad[2])

                # Update dw
                # dw becomes the gradients w.r.t. the updated w for previous step
                dw.add_(hvp_grad[0])

        return gdatas, glrs

saved = forward(model, rdata, rlabel, steps)
grad_infos = backward(model, rdata, rlabel, steps, saved)
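
For context, here is a rough sketch of how these gradients can then be applied to the synthesized data and lrs. This is my own sketch, not the paper’s exact code: it assumes each data and lr in steps is a leaf tensor with requires_grad=True, and the SGD optimizer and its lr below are just placeholders.

# Rough sketch of the outer update (optimizer choice and lr are placeholders).
# gdatas and glrs come back in reversed step order, so flip them before use.
gdatas, glrs = grad_infos
data_tensors = [data for (data, label, lr) in steps]
lr_tensors = [lr for (data, label, lr) in steps]
optimizer = torch.optim.SGD(data_tensors + lr_tensors, lr=0.01)

optimizer.zero_grad()
for (data, label, lr), gdata, glr in zip(steps, reversed(gdatas), reversed(glrs)):
    data.grad = gdata
    lr.grad = glr
optimizer.step()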

So here is my problem: when synthesizing data over multiple steps and multiple epochs, the buffers (e.g. the running mean/var in bn layers) should either be tracked and added to the computation graph, or stay constant. Otherwise the backprop computation will be incorrect, since the wrong buffer values will be used.
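
To make the issue concrete, here is a tiny standalone snippet (separate from the algorithm above) showing that the bn buffers silently change after every forward in train mode, outside of any computation graph:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.train()
print(bn.running_mean)            # all zeros right after construction
_ = bn(torch.randn(8, 3, 4, 4))   # a train-mode forward updates the buffers in-place
print(bn.running_mean)            # now non-zero, and this update is not tracked by autograd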

My question is: how can I add those buffers to the computation graph while keeping the bn layers behaving normally? Or should I instead restore the buffers of each step and epoch at the appropriate time?

Thanks a lot!

Edit: in case this example is too complex to follow, please refer to the algorithm as presented in the paper.

Hi,

I’m not sure what you mean by “how to add those buffers into the computation graph”.
Do you want them to have a fixed value and never be updated?

I mean that if the model is trained on the synthesized data for multiple steps / multiple epochs, and we just backprop as in the forward and backward example above, then the gradient computation will be incorrect. After multiple steps / epochs, the buffer values (running mean/var) have been updated. Since these buffer values are not part of the computation graph, the updated buffer values are used during backprop instead of the values that were actually in place at each step, especially when computing the gradients of the early steps/epochs.

Also, I do want them to be updated; otherwise the bn layers are not fully functional.

Sorry, I am a bit confused by your naming. Could you define precisely what you mean by “updated buffer”, “real buffer” and “the buffer value is not in the computation graph”?

Also, when you use batchnorm in train mode, the running_mean and running_var are not used; they are only updated so that they can be used in .eval() mode.
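
You can check this with a quick snippet:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4)

bn.train()
out1 = bn(x)                       # normalized with the statistics of this batch
with torch.no_grad():
    bn.running_mean.fill_(100.)    # corrupt the running stats on purpose
out2 = bn(x)                       # still train mode: output is unchanged
print(torch.allclose(out1, out2))  # True -> running stats are not used in train mode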

I’m sorry, perhaps those names were not 100% accurate, and I apologize for the confusion. Perhaps this phrasing is easier to follow.

The problem is that the dataset distillation algorithm assumes the model weights (and running mean/var) stay unchanged during distillation (i.e. across multiple runs of forward and backward). But after one run of forward and backward, the running mean/var have changed, which leads to incorrect gradient computation in the next run of forward and backward. Also, the number of synthesized images is very small (10-100 images), so we cannot simply use the statistics of the synthesized images. The naive solution would be to keep the batch norm statistics fixed, e.g. by always running the bn layers in eval mode (or by disabling the running stats with track_running_stats=False), which hurts the model’s capacity; a sketch of this workaround is shown below. A cleverer solution would be to track the buffers in the computation graph and then use gradients to update the running mean/var, just like synthesizing the images in the backward function.
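
Just to be explicit about what I mean by the naive workaround, here is a sketch (the freeze_bn helper name is mine):

import torch.nn as nn

# Keep every bn layer in eval mode so its running stats are never updated
# between distillation runs (this is the workaround, not the desired solution).
def freeze_bn(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

# e.g. call freeze_bn(model) right after model.train() in forward()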

And my question is: how can I add the running mean/var to the computation graph while keeping the bn layer behavior normal?

So you want to change the batchnorm layer to behave just like a batchnorm in eval mode, with the running_mean and running_var being parameters that you train jointly with the rest of the model?

That’s right!
During training, I want the bn layer to use the batch mean/var, while running_mean and running_var are trained like regular model parameters. During evaluation, I want the trained running_mean and running_var to be used just like the regular running_mean and running_var in a standard bn layer.
Is there any way to do so?

You will need to subclass batchnorm to make that happen.
Here is an example for the 2d version:

import torch
import torch.nn as nn

class MyLearntBatchnorm(nn.BatchNorm2d):
  def __init__(self, *args, **kwargs):
    # Initialize the regular batchnorm
    super().__init__(*args, **kwargs)
    # Get the size of the running_* buffers
    stats_size = self.running_mean.size()
    # Unregister them as buffer by deleting them
    del self.running_mean
    del self.running_var
    # Make them learnable parameters now
    self.running_mean = nn.Parameter(torch.Tensor(stats_size))
    self.running_var = nn.Parameter(torch.Tensor(stats_size))
    # Initialize them whichever way you want (I don't think uniform is a good idea here but not sure what to use)
    with torch.no_grad():
      self.running_mean.uniform_()
      self.running_var.uniform_()

  def forward(self, *args, **kwargs):
    # Our version behaves the same during train and eval and mimics the original eval mode
    orig_training = self.training
    self.training = False
    # Use the regular batchnorm forward
    out = super().forward(*args, **kwargs)
    # Restore the original training flag
    self.training = orig_training
    return out

I haven’t tried the code, so there might be some typos, but it should clearly present the idea.
What could block this approach is that the gradients for running_mean and running_var may not be implemented in the built-in batchnorm.
In that case, you will have to reimplement the batchnorm by hand using Python functions.
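
In case it helps, here is a rough, untested sketch of how you could swap this layer into an existing model (the replace_bn helper is just for illustration, and it does not copy over the existing affine weights):

import torch.nn as nn

def replace_bn(module):
    # Recursively replace every BatchNorm2d in your model with the learnable version
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, MyLearntBatchnorm(child.num_features,
                                                    eps=child.eps,
                                                    affine=child.affine))
        else:
            replace_bn(child)

replace_bn(model)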


So, in addition to the batchnorm in your example code, do I have to implement a backward method for batchnorm to manually compute the gradients for running_mean and running_var?

You can use this new batchnorm as a drop-in replacement.
If it runs fine, then you’re done.

If it complains that it cannot compute gradients for running_mean and running_var, then you will have to change the forward() function to implement the batchnorm computation yourself:

# This is not the exact code you want, just a quick version to give you an example
def forward(self, input):
    # Reshape the per-channel stats to broadcast over an (N, C, H, W) input
    mean = self.running_mean[None, :, None, None]
    var = self.running_var[None, :, None, None]
    out = (input - mean) / torch.sqrt(var + self.eps)
    # The affine transformation from batchnorm
    out = out * self.weight[None, :, None, None] + self.bias[None, :, None, None]
    return out

This will be a bit slower / consume a bit more memory than the fully optimized one, but it will do what you want :)

Ah, I see! Thank you, I think that solves my problem.
Also, I’m sorry for any inconvenience or confusion I’ve caused; I’m new here and still learning how to ask a question properly.
Thanks again!

No problem, happy to help! Glad that this works for you!