Parameter.grad of conv weight is None after Virtual Batch Normalization

Hi all!

I have an implementation of Virtual Batch Normalization (VBN) that I’m using on each convolution of a discriminator network similar to DCGAN; more specifically, it is the discriminator network from the SEGAN paper.

Let’s assume disc is the state of the Discriminator after a forward pass and that d_loss is the loss of disc after the same forward pass.
We then back-propagate the loss and take a peek at the norm of the gradients, only to find that p.grad is equal to None.

The code works normally if I remove the VBN after each conv layer, or if in VBN’s forward method I set x.data = out and return x.
Any help is appreciated!!!

            d_loss.backward()
            d_grad_norm = 0
            for k, p in disc.named_parameters():
                d_grad_norm += p.grad.data.norm()

class Discriminator(nn.Module):
    def __init__(self, ndf, kernel_size):
        super(Discriminator, self).__init__()
        self.encoder = nn.ModuleList([
            # 16
            nn.Conv1d(1, ndf, kernel_size, 2, 15, 1),
            VBN(ndf),
            # 32
            nn.Conv1d(ndf, ndf * 2, kernel_size, 2, 15, 1),
            VBN(ndf * 2),
            # more convolutions... and finally
            # Linear for output
            nn.Linear(8, 1)
        ])

class VBN(Module):
    """
    Virtual Batch Normalization
    """

    def __init__(self, n_features, epsilon=1e-5):
        super(VBN, self).__init__()
        assert isinstance(epsilon, float)

        # batch statistics
        self.epsilon = epsilon
        self.mean = torch.zeros(1, n_features, 1)
        self.mean_sq = torch.zeros(1, n_features, 1)
        self.batch_size = None
        # reference output
        self.reference_output = None

    def initialize(self, x):
        # compute batch statistics
        # self.mean = torch.mean(x, [0, 2], keepdim=True)
        # self.mean_sq = torch.mean(x**2, [0, 2], keepdim=True)
        self.mean = x.data.mean(2).mean(0).resize_(1, x.size(1), 1)
        self.mean_sq = (x.data**2).mean(2).mean(0).resize_(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert self.mean is not None
        assert self.mean_sq is not None
        # compute reference output
        out = self._normalize(x, self.mean, self.mean_sq)
        self.reference_output = out

    def forward(self, x):
        if self.reference_output is None:
            self.initialize(x)
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        # new_mean = torch.mean(x, [0, 2], keep_dims=True)
        # new_mean_sq = torch.mean(x**2, [0, 2], keep_dims=True)
        new_mean = x.data.mean(2).mean(0).resize_as_(self.mean)
        new_mean_sq = (x.data**2).mean(2).mean(0).resize_as_(self.mean_sq)
        mean = new_coeff * new_mean + old_coeff * self.mean
        mean_sq = new_coeff * new_mean_sq + old_coeff * self.mean_sq
        out = self._normalize(x, mean, mean_sq)
        return Variable(out)

    def _normalize(self, x, mean, mean_sq):
        assert self.epsilon is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        gamma = torch.normal(means=torch.ones(1, x.size(1), 1), std=0.02)
        gamma = gamma.float().cuda(async=True)
        beta = torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0)
        std = torch.sqrt(self.epsilon + mean_sq - mean**2)
        out = x.data - mean
        out = out / std
        out = out * gamma
        out = out + beta
        return out

    def __repr__(self):
        return ('{name}(eps={epsilon}, mean={mean}, mean_sq={mean_sq})'.format(
            name=self.__class__.__name__, **self.__dict__))

In your VBN implementation’s forward, you are operating directly on the tensor inside the Variable x, i.e., you are using x.data, so you lose all dynamic graph building there. Hence no gradient backprop. Just do the operations on x; don’t extract x.data.
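
For example, a minimal sketch (not your exact code) of what the normalization step looks like when everything stays a Variable; gamma and beta are passed in as plain Variables here just to illustrate, in a module they would be Parameters:

import torch
from torch.autograd import Variable

def vbn_normalize(x, mean, mean_sq, gamma, beta, eps=1e-5):
    # every input is a Variable and .data is never touched, so autograd keeps the graph
    std = torch.sqrt(eps + mean_sq - mean ** 2)
    return (x - mean) / std * gamma + beta

# e.g. with Variables of shape (N, C, L) and (1, C, 1) statistics
x = Variable(torch.randn(8, 16, 32), requires_grad=True)
mean = x.mean(2).mean(0).view(1, 16, 1)
mean_sq = (x ** 2).mean(2).mean(0).view(1, 16, 1)
gamma = Variable(torch.ones(1, 16, 1))
beta = Variable(torch.zeros(1, 16, 1))
out = vbn_normalize(x, mean, mean_sq, gamma, beta)
out.sum().backward()  # x.grad is populated, since the graph was never broken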

I think I had what SimonW suggested before, with gamma and beta set as Variables (otherwise I wouldn’t be able to do the operation). If I remember correctly, that setup failed during the backward step.

You should definitely work directly with Variables rather than tensors. If you can post your Variable version of the code, I can take a look :slight_smile:

Thank you! With the code below I get the following error:

RuntimeError: Trying to backward through the graph a second time, 
but the buffers have already been freed. Specify retain_graph=True 
when calling backward the first time.

    def initialize(self, x):
        # compute batch statistics
        self.mean = x.mean(2).mean(0).resize(1, x.size(1), 1)
        self.mean_sq = (x**2).mean(2).mean(0).resize(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert self.mean is not None
        assert self.mean_sq is not None
        # compute reference output
        out = self._normalize(x, self.mean, self.mean_sq)
        self.reference_output = out

    def forward(self, x):
        if self.reference_output is None:
            self.initialize(x)
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        new_mean = x.mean(2).mean(0).resize_as(self.mean)
        new_mean_sq = (x**2).mean(2).mean(0).resize_as(self.mean_sq)
        mean = new_coeff * new_mean + old_coeff * self.mean
        mean_sq = new_coeff * new_mean_sq + old_coeff * self.mean_sq
        x = self._normalize(x, mean, mean_sq)
        return x

    def _normalize(self, x, mean, mean_sq):
        assert self.eps is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        gamma = Variable(torch.Tensor(1, x.size(1), 1).cuda().normal_(1., 0.02))
        beta = Variable(torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0))
        std = torch.sqrt(self.eps + mean_sq - mean**2)
        x = x - mean
        x = x / std
        x = x * gamma
        x = x + beta
        return x

Your self.mean and self.mean_sq are also backed by the dynamically built graph. When you forward on a different x, self.mean and self.mean_sq stay the same. Then, when you backward from the output of forward, it will backprop through the graph that computed self.mean and self.mean_sq multiple times.

I don’t know what the intended behavior is. Do you want to backprop through those variables or not? AFAIK, in vanilla BN it is important to backprop through the batch statistics so that nothing is hidden from the optimizer. But here, it seems that you are using the same statistics for different input batches x.
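
To make the failure mode concrete, here is a small standalone sketch (independent of your code) that reproduces the same “backward through the graph a second time” error by reusing a statistic that still belongs to an old graph:

import torch
from torch.autograd import Variable

w = Variable(torch.randn(3, 1), requires_grad=True)
x_ref = Variable(torch.randn(4, 3))
ref_stat = torch.mm(x_ref, w).mean()  # built once; its graph hangs off w

for step in range(2):
    x = Variable(torch.randn(4, 3))
    loss = (torch.mm(x, w).mean() + ref_stat) ** 2
    # the first iteration is fine; the second tries to backprop through
    # ref_stat's already-freed graph and raises the RuntimeError above
    loss.backward()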

From what I understand, the idea of virtual batch norm is to use the statistics of a held-out (reference) batch for every input batch x.
I’m porting this tensorflow code:

I see. So from my understanding, it also needs to backward through the held-out batch’s statistics. This can be a bit tricky since you want to retain part of the graph. Here is what I would do:

    def initialize(self, x):
        # compute batch statistics
        mean = x.mean(2).mean(0).resize(1, x.size(1), 1)
        mean_sq = (x**2).mean(2).mean(0).resize(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert mean is not None
        assert mean_sq is not None
        # build detached variables to avoid backprop to graph to compute mean and mean_sq
        # we will manually backprop those in hooks
        self.mean = autograd.Variable(mean.data.clone(), requires_grad = True)  # new code
        self.mean_sq = autograd.Variable(mean_sq.data.clone(), requires_grad = True)  # new code
        self.mean.register_hook(lambda grad: mean.backward(grad, retain_graph = True))  # new code
        self.mean_sq.register_hook(lambda grad: mean_sq.backward(grad, retain_graph = True))  # new code
        # compute reference output
        out = self._normalize(x, mean, mean_sq)
        self.reference_output = out.detach_()  # change, just to remove unnecessary saved graph

    def forward(self, x):
        if self.reference_output is None:
            self.initialize(x)
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        new_mean = x.mean(2).mean(0).resize_as(self.mean)
        new_mean_sq = (x**2).mean(2).mean(0).resize_as(self.mean_sq)
        mean = new_coeff * new_mean + old_coeff * self.mean
        mean_sq = new_coeff * new_mean_sq + old_coeff * self.mean_sq
        x = self._normalize(x, mean, mean_sq)
        return x

    def _normalize(self, x, mean, mean_sq):
        assert self.eps is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        gamma = Variable(torch.Tensor(1, x.size(1), 1).cuda().normal_(1., 0.02))
        beta = Variable(torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0))
        std = torch.sqrt(self.eps + mean_sq - mean**2)
        x = x - mean
        x = x / std
        x = x * gamma
        x = x + beta
        return x

I didn’t test it, so let me know if there are more issues.

Thanks for posting this, but there are still issues: running this code, d_loss.backward() doesn’t return even after 4 minutes.

I looked at the source code of batch norm and there are no registered hooks for computing grads, at least not in the BatchNorm class.

BN and VBN are different.

In BN, the graph of each batch is only built by that batch. So you can forward (building graph) and backward (and then throwing away that graph).

In VBN, you have part of the graph built by a reference batch. More importantly, you want to retain that graph part for backward using different inputs. However, backward will automatically go all the way back to the leaf variables. So I used dummy detached reference batch statistics, and backward through the original ones afterwards; see the boiled-down sketch below.
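
Roughly, the dummy-variable-plus-hook pattern on a toy statistic (same mechanism as the code I posted earlier, untested beyond this toy case):

import torch
from torch.autograd import Variable

w = Variable(torch.randn(3, 1), requires_grad=True)
x_ref = Variable(torch.randn(4, 3))
ref_stat = torch.mm(x_ref, w).mean()  # graph-connected reference statistic

# detached dummy that each per-batch graph uses in place of ref_stat
dummy = Variable(ref_stat.data.clone(), requires_grad=True)
# when a gradient reaches the dummy, push it through the retained reference graph
dummy.register_hook(lambda grad: ref_stat.backward(grad, retain_graph=True))

for step in range(2):
    x = Variable(torch.randn(4, 3))
    loss = (torch.mm(x, w).mean() + dummy) ** 2
    loss.backward()  # each iteration frees only its own graph; the reference graph is kept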

That said, it is interesting that the code hangs. Looking at it, I can’t think of a reason. Let me try it. How is VBN used in your code?

edit: oh I see it’s in discriminator.

Got it. I’m looking at the graph of the tensorflow code and it seems that only gamma and beta go into the optimizer.
VBN is used in the Discriminator after each convolution and before the non-linearities.
Let me know if you need more info about how VBN is used in my code.

Can you try this? It’s not the most efficient but it should be better.

    def initialize(self, x):
        # compute batch statistics
        mean = x.mean(2).mean(0).resize(1, x.size(1), 1)
        mean_sq = (x**2).mean(2).mean(0).resize(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert mean is not None
        assert mean_sq is not None
        # build detached variables to avoid backprop to graph to compute mean and mean_sq
        # we will manually backprop those in hooks
        self.mean = autograd.Variable(mean.data.clone(), requires_grad = True)  # new code
        self.mean_sq = autograd.Variable(mean_sq.data.clone(), requires_grad = True)  # new code
        self.mean.register_hook(lambda grad: mean.backward(grad, retain_graph = True))  # new code
        self.mean_sq.register_hook(lambda grad: mean_sq.backward(grad, retain_graph = True))  # new code
        # compute reference output
        out = self._normalize(x, mean, mean_sq)
        self.reference_output = out.detach_()  # change, just to remove unnecessary saved graph
        return mean, mean_sq
   
    def get_ref_batch_stats(self):
        return self.mean, self.mean_sq

    def forward(self, x):
        if self.reference_output is None:
            ref_mean, ref_mean_sq = self.initialize(x)  # here
        else:  # here
            ref_mean, ref_mean_sq = self.get_ref_batch_stats()  # here
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        new_mean = x.mean(2).mean(0).resize_as(ref_mean)  # here
        new_mean_sq = (x**2).mean(2).mean(0).resize_as(ref_mean_sq)  # here
        mean = new_coeff * new_mean + old_coeff * ref_mean  # here
        mean_sq = new_coeff * new_mean_sq + old_coeff * ref_mean_sq  # here
        x = self._normalize(x, mean, mean_sq)
        return x

    def _normalize(self, x, mean, mean_sq):
        assert self.eps is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        gamma = Variable(torch.Tensor(1, x.size(1), 1).cuda().normal_(1., 0.02))
        beta = Variable(torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0))
        std = torch.sqrt(self.eps + mean_sq - mean**2)
        x = x - mean
        x = x / std
        x = x * gamma
        x = x + beta
        return x

Additionally, you should add gamma and beta as parameters for them to be optimized. They are currently not.

Yes, it runs now and I’ve set gamma and beta as Parameters!
Thanks a lot for your help! Is this you http://ssnl.github.io/about/?

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.modules import Module
import pdb


class VBN(Module):
    """
    Virtual Batch Normalization
    """

    def __init__(self, num_features, eps=1e-5):
        super(VBN, self).__init__()
        assert isinstance(eps, float)

        # batch statistics
        self.num_features = num_features
        self.eps = eps
        self.mean = torch.zeros(1, num_features, 1)
        self.mean_sq = torch.zeros(1, num_features, 1)
        self.batch_size = None
        # reference output
        self.reference_output = None
        gamma = torch.normal(means=torch.ones(1, num_features, 1), std=0.02)
        self.gamma = Parameter(gamma.float().cuda(async=True))
        self.beta = Parameter(torch.cuda.FloatTensor(1, num_features, 1).fill_(0))

    def initialize(self, x):
        # compute batch statistics
        mean = x.mean(2).mean(0).resize(1, x.size(1), 1)
        mean_sq = (x**2).mean(2).mean(0).resize(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert mean is not None
        assert mean_sq is not None
        # build detached variables to avoid backprop to graph to compute mean and mean_sq
        # we will manually backprop those in hooks
        self.mean = Variable(mean.data.clone(), requires_grad = True)  # new code
        self.mean_sq = Variable(mean_sq.data.clone(), requires_grad = True)  # new code
        self.mean.register_hook(lambda grad: mean.backward(grad, retain_graph = True))  # new code
        self.mean_sq.register_hook(lambda grad: mean_sq.backward(grad, retain_graph = True))  # new code
        # compute reference output
        out = self._normalize(x, mean, mean_sq)
        self.reference_output = out.detach_()  # change, just to remove unnecessary saved graph
        return mean, mean_sq

    def get_ref_batch_stats(self):
        return self.mean, self.mean_sq

    def forward(self, x):
        if self.reference_output is None:
            ref_mean, ref_mean_sq = self.initialize(x)
        else:
            ref_mean, ref_mean_sq = self.get_ref_batch_stats()
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        new_mean = x.mean(2).mean(0).resize_as(self.mean)
        new_mean_sq = (x**2).mean(2).mean(0).resize_as(self.mean_sq)
        mean = new_coeff * new_mean + old_coeff * ref_mean  # change
        mean_sq = new_coeff * new_mean_sq + old_coeff * ref_mean_sq  # change
        x = self._normalize(x, mean, mean_sq)
        return x

    def _normalize(self, x, mean, mean_sq):
        assert self.eps is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        #gamma = Variable(torch.Tensor(1, x.size(1), 1).cuda().normal_(1., 0.02))
        #beta = Variable(torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0))
        std = torch.sqrt(self.eps + mean_sq - mean**2)
        x = x - mean
        x = x / std
        x = x * self.gamma
        x = x + self.beta
        return x

    def __repr__(self):
        return ('{name}(num_features={num_features}, eps={eps})'.format(
            name=self.__class__.__name__, **self.__dict__))

The code above still has some efficiency issues. Specifically, it backprops each layer’s reference batch statistics independently. It would benefit from backpropping them all together, since they share parts of the computation graph.

If you want to do that, remove the backward hooks, and after you backward from d_loss, collect a list of grads on the dummy variables, and use `autograd.backward` (http://pytorch.org/docs/0.2.0/autograd.html#torch.autograd.backward) to backward on all ref stats with the grad list. It’s not the most elegant way :slight_smile:. You might be able to figure out a better approach.
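
Something along these lines, untested — the ref_mean/ref_mean_sq and mean/mean_sq attribute names are just my assumption about how you would keep both the graph-connected stats and the detached dummies around:

from torch import autograd

def backward_ref_stats(d_loss, vbn_layers):
    # d_loss: discriminator loss; vbn_layers: list of the discriminator's VBN modules
    d_loss.backward()  # populates .grad on each layer's detached dummy stats
    variables, grads = [], []
    for layer in vbn_layers:
        # layer.ref_mean / layer.ref_mean_sq: graph-connected stats saved in initialize()
        # layer.mean / layer.mean_sq: the detached dummies (no hooks registered on them)
        variables += [layer.ref_mean, layer.ref_mean_sq]
        grads += [layer.mean.grad, layer.mean_sq.grad]
    # a single backward call through the shared reference-batch graph
    autograd.backward(variables, grads)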

Yes that is me haha. It is a bit outdated though.

If we run d_loss.backward() wouldn’t the error below occur again?

RuntimeError: Trying to backward through the graph a second time, 
but the buffers have already been freed. Specify retain_graph=True 
when calling backward the first time.

Since we already introduced those dummy variables, it won’t happen :). But that is indeed a more complicated approach, and it would be easy to write bugs.

Right! I sent you an e-mail on simon.l berkeley e-mail!
Thanks again.


Actually, forget everything I’ve said. I reread the original paper on improved GAN training. With VBN, you want to fix a training batch, rather than fix the statistics. The above approach is problematic:

The layer stats are not recomputed every time, so they become outdated as the parameters change. Not only are they the wrong stats (w.r.t. the VBN definition), we are also backpropping wrong gradients through the saved graph.

Therefore, what you need to do is explicitly save a reference batch (if I understand correctly, the improved GAN code also does this, just not in the VBN class). Then, for each activation, you need to compute it twice: first on the reference batch and then on the input batch. This is aligned with the original paper.
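
A rough, untested sketch of what that could look like as a module — the caller is responsible for pushing the fixed reference batch through the same layers to produce x_ref on every forward pass, and the names here are purely illustrative:

import torch
import torch.nn as nn

class VBNTwoPass(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(VBNTwoPass, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1).normal_(1., 0.02))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1))

    def forward(self, x, x_ref):
        # x_ref: the reference batch's activation at this layer, recomputed by the
        # caller with the current weights, so the statistics are never outdated
        ref_mean = x_ref.mean(2, keepdim=True).mean(0, keepdim=True)
        ref_mean_sq = (x_ref ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
        coeff = 1. / (x_ref.size(0) + 1.)
        mean = coeff * x.mean(2, keepdim=True).mean(0, keepdim=True) + (1. - coeff) * ref_mean
        mean_sq = coeff * (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True) + (1. - coeff) * ref_mean_sq
        std = torch.sqrt(self.eps + mean_sq - mean ** 2)
        return (x - mean) / std * self.gamma + self.beta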

Sorry for incorrect answers from yesterday!

No sweat! This makes me think that the implementation of virtual batch normalization on the SEGAN github is wrong, although it trains the network properly. Can you confirm? I don’t see multiple forward passes on the reference batch anywhere!

In link 1 they define the VBN class. In link 2 they define the graph where the hidden state of the layer goes through the VBN layer before the non-linearity. In link 3 they instantiate the VBN class if it’s the first batch seen; otherwise they call VBN on the current hidden state.

  1. https://github.com/santi-pdp/segan/blob/master/bnorm.py

  2. https://github.com/santi-pdp/segan/blob/1774069184e3e9666170503a201f46969754459e/discriminator.py#L47

  3. https://github.com/santi-pdp/segan/blob/b2463fc5172f6ac1fcb0e89313be953c2dc0138e/model.py#L297

My TF skills are quite rough… Here’s the relevant part in improved-gan, I believe: https://github.com/openai/improved-gan/blob/master/imagenet/build_model.py#L60

Maybe it’s here: https://github.com/santi-pdp/segan/blob/master/model.py#L184