How to correctly implement in-place Max Norm constraint?

I’m implementing the max norm constraint as detailed in this post. Max Norm constrains the parameters of a given layer based on the L2 norm of the weights.

It’s my understanding that the operations should be done in-place for memory efficiency. I have the following code.

def forward(self, x):
    x = x.view(-1, n_inputs)
    self.max_norm_(layer.weight)
    x = layer(x)
    return x


def max_norm_(self, w):
    with torch.no_grad():
        norm = w.norm(2, dim=0, keepdim=True)
        desired = torch.clamp(norm, 0, self._max_norm_val)
        w *= (desired / (self._eps + norm))

Is this the correct way to implement it? Will this interfere with the optimization step? I’m using torch.no_grad() since we don’t need gradients for this operation.

Thanks

It’s likely not memory efficiency that matters here (you’ll have other ops that use more memory, and you could do this after the optimizer step, when memory is less precious than after the forward), but rather that you are changing a parameter, which is why you want the operation to be in-place.
The implementation you propose looks like the one you linked, but note that it will scale the weight by norm/(norm+eps) when the norm is within your bound. This may or may not be desired. If not, it might be better to just clamp both norm and desired away from 0 or so.
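
To make the norm/(norm+eps) effect concrete, here is a minimal sketch. The values of max_norm_val and eps and the example tensor are assumptions chosen for illustration, not taken from the original code:

```python
import torch

max_norm_val, eps = 3.0, 1e-8  # illustrative values (assumptions)

# Every column has L2 norm 2e-3, far below the bound.
w = torch.full((4, 3), 1e-3, dtype=torch.float64)

norm = w.norm(2, dim=0, keepdim=True)
desired = torch.clamp(norm, 0, max_norm_val)  # equals norm, since norm < bound
scale = desired / (eps + norm)                # norm / (norm + eps)

print(scale)  # slightly below 1 for every column
```

Each call therefore shrinks the weights a little even when the constraint is not active, and the shrinkage is relatively larger the smaller the norm is.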

Best regards

Thomas


It’s likely not memory efficiency that matters here (you’ll have other ops that use more memory, and you could do this after the optimizer step, when memory is less precious than after the forward), but rather that you are changing a parameter, which is why you want the operation to be in-place.

That’s a good point. I don’t need to do that operation in the forward step.

The implementation you propose looks like the one you linked, but note that it will scale the weight by norm/(norm+eps) when the norm is within your bound. This may or may not be desired. If not, it might be better to just clamp both norm and desired away from 0 or so.

Thanks for catching that. You are right, norm/(norm + eps) is indeed not the desired behavior. When you say clamp both norm and desired away from 0, do you mean something like w *= torch.clamp(desired / (eps + norm), min=1), where we ensure that there is no effect when norm is within the bounds?

Never mind, I get what you mean. Just make sure that neither norm nor desired is close to zero.

Thanks, very helpful comments! Appreciate it!
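
The quotient-clamping idea raised above can be checked numerically. This is a small sketch with made-up values (max_norm_val, eps, and the tensor are assumptions): since desired never exceeds norm, the quotient is at most 1, so clamping it with min=1 turns every scale factor into exactly 1 and the constraint is never applied.

```python
import torch

max_norm_val, eps = 1.0, 1e-8  # illustrative values (assumptions)

# Column norms are 5.0 (above the bound) and about 0.22 (below it).
w = torch.tensor([[3.0, 0.1], [4.0, 0.2]], dtype=torch.float64)

norm = w.norm(2, dim=0, keepdim=True)
desired = torch.clamp(norm, 0, max_norm_val)
quotient = desired / (eps + norm)      # never larger than 1
scale = torch.clamp(quotient, min=1)   # hence exactly 1 everywhere

clipped = w * scale                    # identical to w, nothing was constrained
print(clipped.norm(2, dim=0))
```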

Just to make it explicit (you noticed, but for the next person looking here), clamping the quotient won’t work, but using

    with torch.no_grad():
        norm = w.norm(2, dim=0, keepdim=True).clamp(min=self._max_norm_val / 2)
        desired = torch.clamp(norm, max=self._max_norm_val)
        w *= (desired / norm)

should always work as expected and leave small norms alone.
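
For the next reader, here is how the pieces might fit together when the constraint is applied after the optimizer step rather than in forward, as suggested earlier in the thread. This is only a sketch: the standalone function signature, the nn.Linear layer, the SGD settings, and the bound of 0.5 are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

def max_norm_(w, max_norm_val=3.0):
    """Rescale w in place so that no slice along dim 0 has L2 norm above max_norm_val."""
    with torch.no_grad():
        # Clamping the norm away from 0 keeps the division safe and makes the
        # scale factor exactly 1 for slices whose norm is within the bound.
        norm = w.norm(2, dim=0, keepdim=True).clamp(min=max_norm_val / 2)
        desired = torch.clamp(norm, max=max_norm_val)
        w *= desired / norm
    return w

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical layer
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x, target = torch.randn(16, 8), torch.randn(16, 4)

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(layer(x), target)
    loss.backward()
    optimizer.step()
    max_norm_(layer.weight, max_norm_val=0.5)  # constrain after the update

col_norms = layer.weight.norm(2, dim=0)
print(col_norms)  # each entry should be at most about 0.5
```

Note that nn.Linear stores its weight with shape (out_features, in_features), so dim=0 as used in this thread computes one norm per input feature. If the constraint is meant per output unit, dim=1 would be the axis to use.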
