How to make weights into leaf node after using an assignment operation on them

I’m adapting a 10-class classification model to an 11-class classification model where I add one extra class. As well as keeping the prior layer weights, I want to make use of as much of the final classifier weights as possible, so I do something like this:

def make_backbone(self, load=''):
    self.backbone = SFNetV1()
    if len(load):
        self.backbone.load(load)

    # replace fully connected layer 
    if len(load):
        prior_weight = self.backbone.fc1.weight
        prior_bias = self.backbone.fc1.bias
    prior_in_features = self.backbone.fc1.in_features
    prior_out_features = self.backbone.fc1.out_features
    
    # add "empty" token
    self.backbone.fc1 = nn.Linear(prior_in_features, prior_out_features+1)

    print(self.backbone.fc1.weight.is_leaf)
    print(prior_weight.is_leaf)

    # reuse whatever weights and biases we still can
    if len(load):
        self.backbone.fc1.weight[:prior_out_features, :] = prior_weight
        self.backbone.fc1.bias[:prior_out_features] = prior_bias

    print(self.backbone.fc1.weight.is_leaf)

This outputs

True
True
False

My problem is that last False. How do I “reset” self.backbone.fc1.weight (and the bias) so that it is a leaf node again?

Bonus side question: Is there a better way to do what I’m trying to do?

You can wrap the assignment in a with torch.no_grad() block so that Autograd doesn’t record the copy as a differentiable operation:

import torch
import torch.nn as nn

fc1 = nn.Linear(10, 10)

prior_weight = fc1.weight
prior_bias = fc1.bias
prior_in_features = fc1.in_features
prior_out_features = fc1.out_features

# add "empty" token
fc2 = nn.Linear(prior_in_features, prior_out_features+1)

print(fc2.weight.is_leaf)
print(prior_weight.is_leaf)

# reuse whatever weights and biases we still can
with torch.no_grad():
    fc2.weight[:prior_out_features, :] = prior_weight
    fc2.bias[:prior_out_features] = prior_bias

print(fc2.weight.is_leaf)

out = fc2(torch.randn(1, 10))
print(out.shape)
out.mean().backward()
print(fc2.weight.grad)
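
On the bonus question, one possible way to tidy this up is to move the copy into a small helper. The following is only a sketch: expand_linear and extra_out are hypothetical names made up for illustration, not part of PyTorch, and it assumes the layer being replaced is a plain nn.Linear. It uses copy_ on a slice inside torch.no_grad(), so the new parameters stay leaf tensors and can be passed to an optimizer.

```python
import torch
import torch.nn as nn


def expand_linear(old: nn.Linear, extra_out: int = 1) -> nn.Linear:
    """Hypothetical helper: build a wider nn.Linear that reuses old's weights."""
    has_bias = old.bias is not None
    new = nn.Linear(old.in_features, old.out_features + extra_out, bias=has_bias)

    # Copy without recording the operation in the autograd graph,
    # so new.weight and new.bias remain leaf tensors.
    with torch.no_grad():
        new.weight[:old.out_features].copy_(old.weight)
        if has_bias:
            new.bias[:old.out_features].copy_(old.bias)
    return new


old_fc = nn.Linear(10, 10)       # stands in for the pretrained 10-class head
new_fc = expand_linear(old_fc)   # 11-class head; the last row is freshly initialized

print(new_fc.weight.is_leaf, new_fc.bias.is_leaf)

# Building an optimizer only works with leaf tensors, so this is a useful check.
optimizer = torch.optim.SGD(new_fc.parameters(), lr=0.1)
```

In the code from the question this would presumably be called as self.backbone.fc1 = expand_linear(self.backbone.fc1) after the checkpoint has been loaded.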