How can I initialize the weights for everything in my class except for self.fcn below? I could write an nn.init.xavier_uniform_() call for every component, but that gets tedious.
class EMBED_MP_MLP(nn.Module):
    def __init__(self, args):
        super(EMBED_MP_MLP, self).__init__()
        # fully convolutional network (pretrained, restored from a checkpoint)
        self.fcn = ALEXNET_FCN()
        self.fcn.load_state_dict(torch.load(args.fcn_restore_path)['classifier_state'])
        # keypoints embedder
        self.kp_embedder = nn.Linear(34, args.kp_embedding_dim)
        # pooling
        self.pool = nn.MaxPool2d(kernel_size=(args.min_obs_len[0], 1))
        # MLP classifier
        mlp_dims = [args.kp_embedding_dim + 121, args.mlp_dim, 2]
        self.mlp = make_mlp(
            mlp_dims,
            activations=["relu", ""],
            batch_norm=args.batch_norm,
            dropout=args.dropout,
        )