This may not be the most elegant approach, but it should get the job done:
class MyModel(nn.Module):
    """Toy module with two linear layers, used to demonstrate selective init.

    Attributes are registered as child modules under the names ``fc`` and
    ``non_fc``; code that filters children by name relies on those names.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        # 10 -> 10 linear layer.
        self.fc = nn.Linear(10, 10)
        # 1 -> 1 linear layer.
        self.non_fc = nn.Linear(1, 1)

    def forward(self, x):
        """Return ``x`` unchanged; neither layer is applied to the input."""
        return x
def weight_init(module):
    """Apply Xavier-normal initialization to the weight of an ``nn.Linear``.

    The weight tensor is re-initialized in place and its shape is printed.
    The bias is not touched. Modules that are not ``nn.Linear`` instances
    are left unchanged.

    Args:
        module: The module to (possibly) initialize.
    """
    if not isinstance(module, nn.Linear):
        return
    print('initializing layer shape: {}'.format(module.weight.shape))
    nn.init.xavier_normal_(module.weight)
model = MyModel()
# Initialize every direct child except those whose name contains 'non_fc'.
# A plain loop is used because weight_init is called only for its side
# effect; a list comprehension would build and discard a list of None.
for name, m in model.named_children():
    if 'non_fc' not in name:
        weight_init(m)