How to train a bias threshold for ReLU?

For the ReLU layer, I would like to add a bias (i.e. ReLU(x + b)) and train it as a threshold, but it doesn't seem to work: the bias doesn't change when I update the parameters. Thanks in advance for your help.

Here is the code:

import torch
import torch.nn as nn
from transformers import BertModel

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        # Bert
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # ReLU
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, ids_query, mask_query, ids_context, segment_ids_context, mask_context):
        # get embeddings
        query_emb = self.bert(ids_query, attention_mask=mask_query).last_hidden_state
        context_emb = self.bert(ids_context, token_type_ids=segment_ids_context, attention_mask=mask_context).last_hidden_state
        print(f"Shape_query: {query_emb.shape}, Shape_context: {context_emb.shape}")
        
        # multiply matrix
        out = torch.bmm(query_emb, torch.transpose(context_emb, 1, 2))
        print("Multiplication: ", out.shape)
        
        # max
        out, _ = torch.max(out, dim=2)
        print("Max-Pool: ", out.shape)
        
        # add bias => out_1 + b
        out = out + self.bias
        
        # relu => ReLU(out_1 + b)
        out = self.relu(out)
        print("ReLU: ", out.shape)
        
        # log
        out = torch.log(out+1)
        print("Log: ", out.shape)
        
        # summation
        out = torch.sum(out, dim=1)
        print("Summation: ", out.shape)
        print("========" * 12)
        
        return out
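To see that the bias actually sits in the computation graph, the pipeline above can be run with dummy tensors in place of the BERT embeddings (the shapes below are hypothetical, chosen only for illustration):

```python
import torch

# Dummy tensors standing in for the BERT outputs:
# batch 2, query length 5, context length 7, hidden size 768.
query_emb = torch.randn(2, 5, 768)
context_emb = torch.randn(2, 7, 768)
bias = torch.zeros(1, requires_grad=True)

out = torch.bmm(query_emb, context_emb.transpose(1, 2))  # (2, 5, 7)
out, _ = torch.max(out, dim=2)                           # (2, 5)
out = torch.relu(out + bias)                             # (2, 5)
out = torch.log(out + 1)                                 # (2, 5)
out = out.sum(dim=1)                                     # (2,)

out.mean().backward()
print(bias.grad)  # the bias receives a gradient through the whole pipeline
```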

Your approach should work and self.bias is correctly defined as a parameter.
Here is a minimal code snippet, which shows that self.bias gets a valid gradient:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # ReLU
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = x + self.bias
        x = self.relu(x)        
        return x

model = Net()
x = torch.randn(1, 1).abs()  # positive input, so the ReLU is active
out = model(x)
out.mean().backward()
print(model.bias.grad)
> tensor([1.])

Depending on your use case and model, the gradient might be small, so the bias wouldn't change by much; you could check its .grad value to confirm it is being updated.
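To see the update explicitly, here is a small standalone sketch (my own example, not part of the original model) that runs one SGD step and prints the bias before and after:

```python
import torch
import torch.nn as nn

# Minimal threshold-ReLU module: ReLU(x + b) with a trainable bias.
class ThresholdReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return torch.relu(x + self.bias)

model = ThresholdReLU()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Strictly positive inputs so the ReLU passes a gradient everywhere.
x = torch.randn(8, 4).abs() + 0.1
before = model.bias.item()

loss = model(x).mean()
loss.backward()
print("grad:", model.bias.grad)  # non-zero gradient flows to the bias
optimizer.step()

print("bias before:", before, "after:", model.bias.item())
```

Since every input is positive, the gradient of the mean w.r.t. the bias is exactly 1, and one step with lr=0.1 moves the bias from 0.0 to -0.1.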

Thanks for your help.
Yes, it actually changed, just not visibly. Checking with .grad works well.