Network Surgery/Modification

I am trying to write a pass that automates applying a per-module parameter change across a complete model. For example, say I get a model as input and I want to wrap all the linear and conv layers with weight_norm if they don't have it already. Reading through the docs, I could think of two ways to do it:

A function to apply weight_norm to a module if it doesn't have it already:

import re
from torch.nn.utils import weight_norm

def applyWeightNorm(module):
    # type(module) is a class, so match against its string form
    if re.findall("Conv2d", str(type(module))) and hasattr(module, "weight"):
        weight_norm(module)
# end applyWeightNorm

Two ways to apply it across the model:

model.apply(applyWeightNorm)

or

for k, v in model.named_modules():
    applyWeightNorm(v)

Does this way of doing it make sense? Is one of these ways better than the other?

I am running in multi-GPU distributed mode, if that matters.
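
For the "if they don't have it already" part, here is a minimal sketch of what I mean (the name applyWeightNormOnce is just for illustration; it assumes the hook-based torch.nn.utils.weight_norm, which registers weight_g and weight_v parameters on the wrapped module, so checking for weight_g tells you whether it has already been applied):

import re
import torch.nn as nn
from torch.nn.utils import weight_norm

def applyWeightNormOnce(module):
    # weight_norm registers weight_g / weight_v parameters on the module,
    # so the presence of weight_g means it is already wrapped.
    already_wrapped = hasattr(module, "weight_g")
    if re.findall("Conv2d", str(type(module))) and hasattr(module, "weight") and not already_wrapped:
        weight_norm(module)

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())  # toy model for illustration
model.apply(applyWeightNormOnce)
model.apply(applyWeightNormOnce)  # second pass is a no-op thanks to the weight_g check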

You can apply the first approach.
However, I would change the condition to:

if isinstance(m, nn.Conv2d):
    weight_norm(m.weight.data)

Would this work for you, or what exactly are you doing in weight_norm?

I am just using the standard PyTorch weight_norm functionality (https://pytorch.org/docs/master/_modules/torch/nn/utils/weight_norm.html). It takes a module as input, so I'm not sure m.weight.data will work. But I do like the isinstance(...) check better than my re.findall hack.

Thanks for that.

Ah, OK, sorry, it sounded like it was your own implementation.

This should work then:

import torch
import torch.nn as nn

def apply_weight_norm(m):
    if isinstance(m, nn.Linear):
        m = torch.nn.utils.weight_norm(m, 'weight')
        

model = nn.Sequential(nn.Linear(10, 1),
                      nn.ReLU())

model.apply(apply_weight_norm)
print(model[0].weight_g.shape)
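
If you ever need to undo it (e.g. before exporting the model), the same traversal works in reverse with torch.nn.utils.remove_weight_norm; a small sketch, reusing the model from above:

from torch.nn.utils import remove_weight_norm

def strip_weight_norm(m):
    # Only modules that were wrapped carry the weight_g parameter.
    if hasattr(m, 'weight_g'):
        remove_weight_norm(m, 'weight')

model.apply(strip_weight_norm)
print(hasattr(model[0], 'weight_g'))  # False: the plain weight parameter is back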