[Input -> Encoder -> Classifier
                  -> Classifier2]
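import torch
import torch.nn as nn

def freeze(module):  # assumed: freeze = set requires_grad=False on every parameter
    for p in module.parameters():
        p.requires_grad_(False)

def unfreeze(module):  # assumed: unfreeze = set requires_grad=True on every parameter
    for p in module.parameters():
        p.requires_grad_(True)

ce = nn.CrossEntropyLoss()
# in train-loop (encoder, classifier, classifier2, input, target defined elsewhere):
freeze(encoder)
unfreeze(classifier)
enc_out = encoder(input)
class_out = classifier(enc_out)
class_loss = ce(class_out, target)
class_loss.backward()  # -> gradients reach only classifier (optimizer step omitted)
unfreeze(encoder)
unfreeze(classifier2)
class_out2 = classifier2(enc_out)  # This part !!!
class_loss2 = ce(class_out2, target)
...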
If I use enc_out when training classifier2, does the encoder also get trained? Or do I have to run enc_out = encoder(input) again after unfreeze(encoder)?
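One way to check this empirically is a tiny experiment. The sketch below is an assumption-laden stand-in, not the original model: it assumes freeze/unfreeze work by toggling requires_grad on every parameter, and it uses small nn.Linear layers as hypothetical replacements for encoder and classifier2. It backpropagates classifier2's loss once with the reused enc_out and once with enc_out recomputed after unfreezing, then checks whether encoder.weight received a gradient.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the modules in the question.
torch.manual_seed(0)
encoder = nn.Linear(8, 4)
classifier2 = nn.Linear(4, 3)
ce = nn.CrossEntropyLoss()
x = torch.randn(5, 8)
target = torch.randint(0, 3, (5,))

# Assumed implementation of freeze/unfreeze: toggle requires_grad on all parameters.
def set_trainable(module, flag):
    for p in module.parameters():
        p.requires_grad_(flag)

def encoder_gets_grad(recompute):
    """Return True if classifier2's loss sends a gradient into the encoder."""
    encoder.zero_grad(set_to_none=True)  # start with encoder.weight.grad == None
    set_trainable(encoder, False)        # freeze(encoder)
    enc_out = encoder(x)                 # forward while frozen: no graph back to encoder
    set_trainable(encoder, True)         # unfreeze(encoder)
    if recompute:
        enc_out = encoder(x)             # re-run forward so autograd records the encoder
    loss2 = ce(classifier2(enc_out), target)
    loss2.backward()
    return encoder.weight.grad is not None

reused = encoder_gets_grad(recompute=False)
recomputed = encoder_gets_grad(recompute=True)
print("reused enc_out trains encoder:", reused)          # prints ... False
print("recomputed enc_out trains encoder:", recomputed)  # prints ... True
```

Under these assumptions, the autograd graph is recorded at the moment the forward pass runs, so unfreezing the encoder afterwards does not retroactively connect an already computed enc_out to its parameters; only the recomputed enc_out lets class_loss2 reach the encoder.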