Initialize weights for all modules except a specific submodule

How can I initialize the weights for everything in my class below except for self.fcn? I could write an nn.init.xavier_uniform_() call for every component, but that gets tedious.

class EMBED_MP_MLP(nn.Module):
    def __init__(self, args):
        super(EMBED_MP_MLP, self).__init__()
        
        # fully convolutional network
        self.fcn = ALEXNET_FCN()     
        self.fcn.load_state_dict(torch.load(args.fcn_restore_path)['classifier_state'])
        
        # keypoints embedder        
        self.kp_embedder = nn.Linear(34, args.kp_embedding_dim)       
        
        # pooling
        self.pool = nn.MaxPool2d(kernel_size=(args.min_obs_len[0], 1))
        
        # MLP classifier        
        mlp_dims = [args.kp_embedding_dim + 121, args.mlp_dim, 2]
        self.mlp = make_mlp(
                mlp_dims,
                activations=["relu",""],
                batch_norm=args.batch_norm,
                dropout=args.dropout
            )

Maybe not the most beautiful approach, but it should get the job done:

import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 10)
        self.non_fc = nn.Linear(1, 1)

    def forward(self, x):
        return x


def weight_init(module):
    # Only (re)initialize layers that actually carry weights
    if isinstance(module, nn.Linear):
        print('initializing layer with shape: {}'.format(module.weight.shape))
        nn.init.xavier_normal_(module.weight)


model = MyModel()
# Skip every direct child whose name contains 'non_fc'
for name, child in model.named_children():
    if 'non_fc' not in name:
        weight_init(child)
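
One caveat: named_children() only yields the direct children, and the isinstance check in weight_init won't reach layers nested inside containers (e.g. if make_mlp returns an nn.Sequential). If your submodules are nested, you can recurse into each kept child with .apply(). A minimal sketch against your class, assuming args is the same namespace you already construct it with:

model = EMBED_MP_MLP(args)

# Initialize every submodule except the pretrained self.fcn;
# .apply() recurses into nested containers such as the MLP's layers
for name, child in model.named_children():
    if name != 'fcn':
        child.apply(weight_init)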