RuntimeError: element 0 of tensors does not require grad and does not have a

I put a classification network consisting of two linear layers in front of a standard transformer. I want to freeze the parameters of the transformer and train only the classification layers in front of it. So I iterated over the parameters of the entire network, set their requires_grad to False, and then set requires_grad to True for the classification layers. However, this causes the error shown in the title.
Interestingly, when I set requires_grad to True for either the encoder or the decoder, everything works again. The entire network also works fine when none of the layers are frozen.

Below are the classifier network and the whole network:

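import torch.nn as nn
from fairseq.models import BaseFairseqModel
from fairseq.models.transformer import TransformerModel

class binaryClassifier(BaseFairseqModel):
    def __init__(self, in_dim, hidden_size=32, out_dim=1):
        super().__init__()  # required before registering submodules
        self.layer1 = nn.Linear(in_dim, hidden_size)
        self.layer2 = nn.Linear(hidden_size, out_dim)
        self.activate = nn.Sigmoid()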

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return self.activate(x)

class binFormerModel(TransformerModel):
    def __init__(self, args, encoder, decoder, in_dim=2, hidden_size=32, out_dim=1):

        super().__init__(args, encoder, decoder)
        self.binaryClassifier = binaryClassifier(in_dim, hidden_size, out_dim)
        for param in self.parameters():
            param.requires_grad = False
        for param in self.binaryClassifier.parameters():
            param.requires_grad = True
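
To double-check that the freezing itself does what is intended, it can help to list which parameters are still trainable. This is a minimal sketch that uses a small stand-in module instead of the fairseq model; `TinyModel`, `backbone`, and `classifier` are hypothetical names chosen for illustration.

```python
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in: a frozen 'backbone' plus a trainable classifier."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # stands in for the encoder/decoder
        self.classifier = nn.Sequential(
            nn.Linear(2, 32), nn.Linear(32, 1), nn.Sigmoid()
        )
        # Freeze everything, then re-enable only the classifier.
        for param in self.parameters():
            param.requires_grad = False
        for param in self.classifier.parameters():
            param.requires_grad = True

model = TinyModel()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:", frozen)
```

If only the classifier parameters show up as trainable, the freezing loop is fine and the problem is more likely in how the classifier output reaches the loss.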

Below is the forward function of binformer:

real_src and real_length are tensors with no gradients.
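
One plausible reading of the error, sketched with small stand-in modules rather than the fairseq model: if the classifier output only reaches the loss through a hard comparison such as pred >= 0.5, the tensors built from that comparison carry no grad_fn. With every other parameter frozen, the loss then has nothing to differentiate and backward() raises a RuntimeError. The names `classifier`, `backbone`, `src_a`, and `src_b` are hypothetical and not taken from the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

classifier = nn.Linear(2, 1)   # trainable head (stand-in)
backbone = nn.Linear(3, 1)     # frozen stand-in for the transformer
for p in backbone.parameters():
    p.requires_grad = False

features = torch.randn(4, 2)
pred = torch.sigmoid(classifier(features))  # shape (4, 1), has a grad_fn

# Hard threshold: the comparison yields a bool tensor without a grad_fn,
# so anything selected with `mask` is cut off from the classifier.
mask = pred >= 0.5
src_a = torch.randn(4, 3)
src_b = torch.randn(4, 3)
real_src = torch.where(mask, src_a, src_b)  # shape (4, 3), no gradient

loss = backbone(real_src).sum()
print(pred.requires_grad, real_src.requires_grad, loss.requires_grad)

try:
    loss.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

That would also be consistent with the error disappearing once the encoder or decoder is unfrozen, since the loss then depends on trainable parameters again.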
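I’m not sure whether pure_source and src_tokens have the same shape, but if they do, something like this might work (torch.where takes the condition as its first argument; swap the last two arguments if the selection should go the other way):

real_src = torch.where(pred >= 0.5, src_tokens, pure_source)

(I can’t test the code right now, so it may raise an error.)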

Thanks for your reply.

pred is a tensor of shape (batch_size, 1)
source_tokens has shape (batch_size, sequence_length1)
pure_source has shape (batch_size, sequence_length2)
Usually sequence_length1 is not equal to sequence_length2,

so I have no idea how to write the torch.where call.
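
When the two sequence lengths differ, one option is to pad both tensors to a common length first and let the (batch_size, 1) mask broadcast over the sequence dimension. This is only a sketch under assumptions: `select_rows` and `PAD_IDX` are hypothetical names, right-padding is assumed, and it is assumed that pure_source should be chosen when pred >= 0.5.

```python
import torch
import torch.nn.functional as F

PAD_IDX = 1  # hypothetical padding index; use your dictionary's pad value

def select_rows(pred, source_tokens, pure_source, threshold=0.5, pad_idx=PAD_IDX):
    """Pick pure_source rows where pred >= threshold, else source_tokens rows."""
    max_len = max(source_tokens.size(1), pure_source.size(1))
    # Right-pad both tensors along the sequence dimension to a common length.
    a = F.pad(source_tokens, (0, max_len - source_tokens.size(1)), value=pad_idx)
    b = F.pad(pure_source, (0, max_len - pure_source.size(1)), value=pad_idx)
    # pred has shape (batch_size, 1), so the mask broadcasts across the sequence.
    return torch.where(pred >= threshold, b, a)

pred = torch.tensor([[0.9], [0.1]])
source_tokens = torch.tensor([[5, 6, 7], [8, 9, 10]])
pure_source = torch.tensor([[11, 12], [13, 14]])
real_src = select_rows(pred, source_tokens, pure_source)
print(real_src)
```

Note that this selection is still a hard threshold, so no gradient flows back to pred through real_src.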

That is strange.
Could real_src[i] = src_tokens[i] work?
It may raise an error since they have different dimensions.

You helped me find another bug. pred is very low at the beginning, so the branch you mentioned was never actually executed.