Freezing with shared weights

from torch.nn import Module

class modelA(Module):
    def __init__(self):
        super().__init__()
        self.base = ...   # shared feature extractor
        self.headA = ...  # task-A head

    def forward(self, input):
        x = self.base(input)
        outA = self.headA(x)
        return outA

class modelB(Module):
    def __init__(self):
        super().__init__()
        self.base = ...   # should share weights with modelA.base
        self.headB = ...  # task-B head

    def forward(self, input):
        x = self.base(input)
        outB = self.headB(x)
        return outB

These models are supposed to work in such a way that the weights of the base layer are shared, so as per the suggestion here by @apaszke I defined a combined model as follows:

class model_combined(Module):
    def __init__(self):
        super().__init__()
        self.base = ...
        self.headA = ...
        self.headB = ...

    def forward(self, input):
        x = self.base(input)   # this branch needs to be frozen
        outA = self.headA(x)
        y = self.base(input)   # this branch needs to backpropagate
        outB = self.headB(y)
        return outA, outB

In the training loop I calculate the loss as:

total_loss = loss(outA) + loss(outB)
total_loss.backward()
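
Roughly, the full loop looks like the sketch below (the criteria, targets, optimizer and loader here are placeholders, not my actual code):

import torch

model = model_combined()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterionA = torch.nn.CrossEntropyLoss()   # placeholder loss for the A head
criterionB = torch.nn.CrossEntropyLoss()   # placeholder loss for the B head

for input, targetA, targetB in loader:     # `loader` is a placeholder DataLoader
    optimizer.zero_grad()
    outA, outB = model(input)
    total_loss = criterionA(outA, targetA) + criterionB(outB, targetB)
    total_loss.backward()
    optimizer.step()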

But after some epochs, I need to freeze modelA (i.e. self.headA in model_combined).
My questions are:

  1. Since base is a common layer here, how do I freeze headA (the weights of headA should not change) in such a way that base is still updated by backpropagating loss(outB), but not by loss(outA)? (A sketch of what I mean by freezing is below this list.)
  2. Would it be bad practice if, instead of freezing, I simply don't backpropagate the loss through headA, i.e. change to total_loss = loss(outB)?
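
For question 1, what I mean by freezing headA is roughly the following sketch (using requires_grad_; I have not verified this for my setup):

# After the chosen epoch: stop updating headA's own weights,
# while base and headB keep receiving updates.
for param in model.headA.parameters():
    param.requires_grad_(False)

# Caveat: loss(outA) would still send gradients *through* headA into base,
# so base would still be affected by loss(outA) unless that term is also
# dropped from total_loss (which is what question 2 is about).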

Hey,

If you don’t want gradients corresponding to headA, then yes, I think the best thing to do is to just not backprop through it.

Note that if you share the weights between A and B, then the weights in A will still change due to the update on B.
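
Concretely, the switch could look like the sketch below (num_epochs, freeze_epoch, the criteria and the loader are placeholders, following the loop you posted):

for epoch in range(num_epochs):
    for input, targetA, targetB in loader:
        optimizer.zero_grad()
        outA, outB = model(input)
        if epoch < freeze_epoch:
            # both heads contribute; base gets gradients from A and B
            total_loss = criterionA(outA, targetA) + criterionB(outB, targetB)
        else:
            # only the B branch is backpropagated, so neither headA nor the
            # A-path contribution to base produces gradients any more
            total_loss = criterionB(outB, targetB)
        total_loss.backward()
        optimizer.step()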

Hi, but won’t this be a problem, since the graph of headA will keep growing without a loss.backward() step for headA (backpropagation frees the graph created during the epoch)? And by sharing of weights, do you mean manually copying the weights from one layer to another, or something else?

by sharing of weights, do you mean manually copying the weights from one layer to another, or something else?

You mention that "base is a common layer here", so I assumed that you have a single set of parameters for both. But maybe I misunderstood this.
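
To illustrate what I mean by a single set of parameters, here is a small sketch with a made-up Linear base (not your actual layers):

import torch
import torch.nn as nn

shared_base = nn.Linear(16, 8)   # one set of parameters
headA = nn.Linear(8, 4)
headB = nn.Linear(8, 4)

x = torch.randn(2, 16)
outA = headA(shared_base(x))     # both branches go through the very same
outB = headB(shared_base(x))     # Module, so an update driven by B's loss
                                 # also changes the weights A sees

# Manually copying instead would create two independent sets of weights:
base_copy = nn.Linear(16, 8)
base_copy.load_state_dict(shared_base.state_dict())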