I’m implementing the max norm constraint as detailed in this post. Max norm constrains the parameters of a given layer based on the L2 norm of the weights.
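Stated as a formula (a sketch of the usual definition, not copied from the linked post), the constraint rescales each weight vector $\mathbf{w}_j$ so that its L2 norm never exceeds a bound $c$. Here $\mathbf{w}_j$ is one column of `w` when the norm is taken over `dim=0`, as in the code in this question, and $c$ plays the role of `self._max_norm_val`:

```latex
\mathbf{w}_j \leftarrow \mathbf{w}_j \cdot \frac{\min\left(\lVert \mathbf{w}_j \rVert_2,\; c\right)}{\lVert \mathbf{w}_j \rVert_2}
```

A vector with $\lVert \mathbf{w}_j \rVert_2 \le c$ is left unchanged, and any other vector is scaled to have norm exactly $c$. In the code, `desired` corresponds to the numerator, and `_eps` is added to the denominator only to avoid dividing by zero.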

It’s my understanding that the operations should be done in-place for memory efficiency. I have the following code.

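import torch

def forward(self, x):
    x = x.view(-1, self.n_inputs)
    self.max_norm_(self.layer.weight)  # rescale the weights in place before using them
    x = self.layer(x)
    return x

def max_norm_(self, w):
    with torch.no_grad():
        # L2 norm of each column of w (reduction over dim=0)
        norm = w.norm(2, dim=0, keepdim=True)
        desired = torch.clamp(norm, 0, self._max_norm_val)
        # Scale each column by desired / (eps + norm); _eps guards against division by zero
        w *= (desired / (self._eps + norm))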

Is this the correct way to implement it? Will this interfere with the optimization step? I’m using torch.no_grad() since we don’t need gradients for this operation.
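One part of the "will this interfere with the optimization step" question can be checked directly: the optimizer keeps references to the Parameter objects it was constructed with. A minimal sketch, assuming a hypothetical `nn.Linear(4, 3)` and a bound of 0.5, shows that an in-place multiply under `torch.no_grad()` keeps that same object, whereas assigning a new `nn.Parameter` to the attribute leaves the optimizer stepping the old tensor.

```python
import torch
import torch.nn as nn

# Hypothetical layer and bound, for illustration only.
layer = nn.Linear(4, 3)
max_norm_val = 0.5
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
param_before = layer.weight

with torch.no_grad():
    norm = layer.weight.norm(2, dim=0, keepdim=True)
    desired = torch.clamp(norm, max=max_norm_val)
    # In-place multiply: updates the existing Parameter's data.
    layer.weight.mul_(desired / norm.clamp(min=1e-12))

same_object = layer.weight is param_before                  # True
tracked = opt.param_groups[0]["params"][0] is layer.weight  # True

# Rebinding the attribute instead registers a brand-new Parameter,
# which the optimizer built above does not know about.
layer.weight = nn.Parameter(layer.weight.detach() * 0.5)
tracked_after_rebind = opt.param_groups[0]["params"][0] is layer.weight  # False
```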

It’s likely not memory efficiency that matters here (you’ll have other ops that use more memory, and you could do this after the optimizer step, when memory is less precious than right after the forward pass), but rather that you want to change the parameter in place.
The implementation you propose looks like the one you linked, but note that it will scale the weight by norm/(norm+eps) when the norm is within your bound. This may or may not be desired. If not, it might be better to just clamp both norm and desired away from 0 or so.
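To see why the norm/(norm+eps) factor matters when the constraint runs on every forward pass, here is a small sketch that reuses the question's arithmetic on a plain tensor. The bound of 2.0 and `eps = 1e-2` are assumptions chosen so the effect is visible after a few calls; with a much smaller eps the per-call factor is closer to 1, so the drift is slower, but it still compounds.

```python
import torch

# Hypothetical values: eps is deliberately large so the drift shows up quickly.
max_norm_val, eps = 2.0, 1e-2

def max_norm_(w):
    # Same arithmetic as the question's max_norm_, without the module.
    with torch.no_grad():
        norm = w.norm(2, dim=0, keepdim=True)
        desired = torch.clamp(norm, 0, max_norm_val)
        w *= desired / (eps + norm)

w = torch.tensor([[0.6], [0.8]])  # one column with L2 norm 1.0, well inside the bound
for _ in range(10):  # e.g. ten forward passes
    max_norm_(w)
print(w.norm().item())  # noticeably below 1.0 although the bound was never exceeded
```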

That’s a good point. I don’t need to do that operation in the forward step.
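A minimal sketch of moving the constraint out of `forward` and applying it right after `optimizer.step()`. The sizes, learning rate, bound, and the free function `max_norm_` are hypothetical. The norm is taken over `dim=0` as in the question; note that `nn.Linear` stores its weight as `(out_features, in_features)`, so a per-output-unit constraint would reduce over `dim=1` instead.

```python
import torch
import torch.nn as nn

# Hypothetical sizes and hyperparameters, for illustration only.
n_inputs, n_outputs, max_norm_val, eps = 8, 4, 1.0, 1e-12

model = nn.Linear(n_inputs, n_outputs)
opt = torch.optim.SGD(model.parameters(), lr=0.5)
loss_fn = nn.MSELoss()

def max_norm_(w, max_val=max_norm_val):
    """Rescale each column of w (norm over dim=0) to L2 norm <= max_val, in place."""
    with torch.no_grad():
        norm = w.norm(2, dim=0, keepdim=True).clamp(min=eps)
        desired = torch.clamp(norm, max=max_val)
        w.mul_(desired / norm)

x = torch.randn(32, n_inputs)
y = torch.randn(32, n_outputs)
for _ in range(5):
    opt.zero_grad()
    loss = loss_fn(model(x), y)  # the forward pass no longer touches the weights
    loss.backward()
    opt.step()
    max_norm_(model.weight)  # constrain after the update, outside autograd
```

Doing it here means the gradient is computed with respect to the weights the optimizer actually holds, and the rescaling itself is never recorded in the autograd graph.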

Thanks for catching that. You are right, norm/(norm + eps) is indeed not the desired behavior. When you say clamp both norm and desired away from 0, do you mean something like w *= torch.clamp(desired / (eps + norm), min=1), where we ensure that there is no effect when norm is within the bounds?
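As written, `torch.clamp(desired / (eps + norm), min=1)` forces the factor to be at least 1, so it could never shrink a column that is over the bound; capping it with `max=1` instead would still leave the norm/(norm+eps) shrinkage for in-bound columns. One way to get "no effect inside the bound" is sketched below, assuming a free function `max_norm_(w, max_val, eps)` (a hypothetical helper, not a PyTorch API): divide the bound itself by the norm and cap that factor at 1.

```python
import torch

def max_norm_(w, max_val, eps=1e-8):
    """In-place max-norm: columns (norm over dim=0) with L2 norm above max_val
    are scaled down to max_val; all other columns are left exactly as they are."""
    with torch.no_grad():
        norm = w.norm(2, dim=0, keepdim=True)
        # max_val / (norm + eps) is >= 1 for in-bound columns (up to eps),
        # so capping it at 1 makes the in-bound factor exactly 1.
        factor = torch.clamp(max_val / (norm + eps), max=1.0)
        w.mul_(factor)
    return w

w = torch.tensor([[3.0, 0.3],
                  [4.0, 0.4]])  # column norms: 5.0 and 0.5
max_norm_(w, max_val=1.0)
print(w)  # first column rescaled to norm 1.0, second column untouched
```

An equivalent option, closer to the reply's suggestion, is to drop eps from the factor and instead write `norm = w.norm(2, dim=0, keepdim=True).clamp(min=eps)`, so that `desired / norm` is exactly 1 whenever the norm is within the bound.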