For your example (constraining variables to be between 0 and 1), there’s no difference between what you’re suggesting (clipping the gradient update) and letting the full gradient update take place and then clipping the weights afterwards. Because clamping is elementwise, shrinking each weight’s update so it stays inside [0, 1] lands on exactly the same value as taking the full step and then clamping. Clipping the weights, however, is much easier than modifying the optimizer.
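A minimal sketch of that equivalence, assuming plain SGD and a single hypothetical parameter tensor `w` that starts inside [0, 1] (the loss and learning rate are arbitrary illustration values). Option A clips the update itself; option B takes the full optimizer step and then clamps the weights in place:

```python
import torch

torch.manual_seed(0)

# Hypothetical parameter that should stay inside [0, 1]; it starts inside the box.
w = torch.nn.Parameter(torch.rand(4, 3))
w0 = w.detach().clone()

# Arbitrary loss whose gradient pushes many entries outside [0, 1].
loss = (w * torch.randn(4, 3) * 10).sum()
loss.backward()

lr = 0.5
opt = torch.optim.SGD([w], lr=lr)

# Option A: clip the update so that w0 + update stays in [0, 1],
# i.e. clamp the update to the range [-w0, 1 - w0] elementwise.
update = -lr * w.grad
clipped_update = torch.maximum(torch.minimum(update, 1 - w0), -w0)
w_clipped_update = w0 + clipped_update

# Option B: take the full, unclipped SGD step, then clip the weights in place.
opt.step()
with torch.no_grad():
    w.clamp_(0.0, 1.0)

print(torch.allclose(w.detach(), w_clipped_update, atol=1e-6))  # True
```

Option B needs no changes to the optimizer at all, which is why clipping the weights is the simpler route.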
Here’s a simple example of a UnitNorm clipper, which rescales each row of a layer’s weight matrix to unit L2 norm:
```python
import torch


class UnitNormClipper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            # keepdim=True keeps the row norms as an (out, 1) column so they broadcast over w
            w.div_(torch.norm(w, 2, 1, keepdim=True).expand_as(w))
```
Instantiate it with `clipper = UnitNormClipper()`, then, after the `optimizer.step()` call, do the following:
```python
model.apply(clipper)
```
Full training loop example:
```python
for epoch in range(nb_epoch):
    for batch_idx in range(nb_batches):
        xbatch = x[batch_idx*batch_size:(batch_idx+1)*batch_size]
        ybatch = y[batch_idx*batch_size:(batch_idx+1)*batch_size]

        optimizer.zero_grad()
        # model-specific: this model returns a pair and defines its own loss
        xp, yp = model(xbatch, ybatch)
        loss = model.loss(xp, yp)
        loss.backward()
        optimizer.step()

        # clip after every step, but only in every `frequency`-th epoch
        if epoch % clipper.frequency == 0:
            model.apply(clipper)
```
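For a version you can run end to end, here is a minimal sketch that assumes a small hypothetical `nn.Sequential` regressor, synthetic random data, and `nn.MSELoss` in place of the custom `model.loss`. The clipper below only touches 2-D weights and modifies them in place under `torch.no_grad()`:

```python
import torch
import torch.nn as nn


class UnitNormClipper(object):
    """Rescale each row of a module's 2-D weight matrix to unit L2 norm."""

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # Only touch 2-D weights (e.g. nn.Linear); skip modules without them.
        w = getattr(module, "weight", None)
        if w is not None and w.dim() == 2:
            with torch.no_grad():
                w.div_(w.norm(p=2, dim=1, keepdim=True))


torch.manual_seed(0)

# Synthetic regression data (hypothetical, for illustration only).
x = torch.randn(64, 5)
y = torch.randn(64, 1)

model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
clipper = UnitNormClipper(frequency=5)

nb_epoch, batch_size = 12, 16
nb_batches = x.shape[0] // batch_size

for epoch in range(nb_epoch):
    for batch_idx in range(nb_batches):
        xbatch = x[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        ybatch = y[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        optimizer.zero_grad()
        loss = loss_fn(model(xbatch), ybatch)
        loss.backward()
        optimizer.step()
        if epoch % clipper.frequency == 0:
            model.apply(clipper)

# Project once more so the final weights satisfy the constraint
# even if the last epoch was not a clipping epoch.
model.apply(clipper)

row_norms = model[0].weight.norm(p=2, dim=1)
print(torch.allclose(row_norms, torch.ones_like(row_norms)))  # True
```

The final `model.apply(clipper)` is a design choice: with `frequency > 1` the constraint only holds right after a clipping epoch, so one last projection guarantees the returned weights satisfy it.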
A 0-1 clipper might look like this (not tested):
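```python
import torch


class ZeroOneClipper(object):

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        # filter the variables to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            # Clamping is a projection: weights already in [0, 1] are left unchanged.
            # (Min-max rescaling, w.sub_(w.min()).div_(w.max() - w.min()), also lands
            # in [0, 1], but it moves every weight and divides by zero if all are equal.)
            w.clamp_(0, 1)
```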
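A quick check of the clamping version, as a sketch that assumes a hypothetical two-layer `nn.Sequential`, random inputs, and one deliberately large SGD step so that weights leave [0, 1] before the clipper runs:

```python
import torch
import torch.nn as nn


class ZeroOneClipper(object):
    """Project each module's weight onto [0, 1] by elementwise clamping."""

    def __init__(self, frequency=5):
        self.frequency = frequency

    def __call__(self, module):
        w = getattr(module, "weight", None)
        if w is not None:
            with torch.no_grad():
                w.clamp_(0.0, 1.0)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# One large step on random data (hypothetical) so weights leave [0, 1].
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
optimizer.step()

model.apply(ZeroOneClipper())
weights = [m.weight for m in model if isinstance(m, nn.Linear)]
print(all(((w >= 0) & (w <= 1)).all().item() for w in weights))  # True
```

Note that only `weight` is constrained here; biases are left free unless the clipper is extended to cover them too.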