Initialize weights except for those of specific layers

Maybe not the most elegant approach, but it should get the job done:

class MyModel(nn.Module):
    """Toy model holding two linear layers, ``fc`` and ``non_fc``.

    ``forward`` is a placeholder: it does not use either layer and simply
    hands its input back.
    """

    def __init__(self):
        super().__init__()
        # Children are registered in this order, which is also the order
        # in which named_children() yields them.
        self.fc = nn.Linear(10, 10)
        self.non_fc = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` unchanged."""
        output = x
        return output

def weight_init(module):
    """Xavier-normal initialize the weight of an ``nn.Linear`` in place.

    Modules of any other type are ignored. Only ``module.weight`` is
    touched; the bias keeps whatever values it already had. For each layer
    it initializes, the weight shape is printed to stdout.

    Args:
        module: the module to (possibly) initialize.

    Returns:
        None.
    """
    if not isinstance(module, nn.Linear):
        return
    weight = module.weight
    print(f'initializing layer shape: {weight.shape}')
    nn.init.xavier_normal_(weight)

model = MyModel()
# Initialize every direct child layer except those whose name contains
# 'non_fc' (a substring test, so e.g. 'non_fc2' would be skipped too).
# A plain loop is used because weight_init is called only for its side
# effects; a list comprehension here would build and discard a list of Nones.
# NOTE: named_children() yields only direct children, so Linear layers nested
# inside sub-modules are not visited; named_modules() would reach them.
for name, child in model.named_children():
    if 'non_fc' not in name:
        weight_init(child)
1 Like