Restrict range of variable during gradient descent

For your example (constraining variables to be between 0 and 1), there is no difference between what you are suggesting, clipping the gradient update, and letting that update take place in full and then clipping the weights afterwards. Clipping the weights, however, is much easier than modifying the optimizer.
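
As a rough illustration of that second approach, here is a minimal sketch (model, optimizer, and the [0, 1] range are placeholders for your own setup, not a fixed API) that clamps every parameter back into range right after each update:

# after optimizer.step() in your training loop:
for p in model.parameters():
    p.data.clamp_(0, 1)  # clip each weight into [0, 1] in place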

Here’s a simple example of a UnitNorm clipper:

import torch


class UnitNormClipper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            # renormalize each row of the weight matrix to unit L2 norm
            w.div_(torch.norm(w, 2, 1, keepdim=True).expand_as(w))

Instantiate this with clipper = UnitNormClipper(); then, after the optimizer.step() call, apply it to the model:

model.apply(clipper)

Full training loop example:

for epoch in range(nb_epoch):
    for batch_idx in range(nb_batches):
        xbatch = x[batch_idx*batch_size:(batch_idx+1)*batch_size]
        ybatch = y[batch_idx*batch_size:(batch_idx+1)*batch_size]

        optimizer.zero_grad()
        xp, yp = model(xbatch, ybatch)
        loss = model.loss(xp, yp)
        loss.backward()
        optimizer.step()

    if epoch % clipper.frequency == 0:
        model.apply(clipper)

A 0-1 clipper might look like this (not tested):

class ZeroOneClipper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            # rescale the weights so that the minimum maps to 0 and the maximum to 1
            w_min, w_max = torch.min(w), torch.max(w)
            w.sub_(w_min).div_(w_max - w_min)
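
Note that ZeroOneClipper rescales the weights rather than clipping them: the smallest weight becomes 0 and the largest becomes 1. If you want hard clipping into [0, 1] instead, a variant in the same style (ZeroOneClamper is just an illustrative name, also untested) could use clamp_:

class ZeroOneClamper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            # clamp in place: values below 0 become 0, values above 1 become 1
            module.weight.data.clamp_(0, 1)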