How to make a scalar/tensor learnable?

Imagine I have a scalar T that is going to be used as a threshold in my network, i.e.
TensorA = torch.where(TensorB > T*Means, Ones, Zeros)
Right now I have T = torch.tensor(1.0), but I want it to be able to change during training, i.e. to be learnable.
What is the way to do that?
In other words, how can I wrap it so that it is learnable?

By learnable, are you referring to requires_grad=True?
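A minimal sketch of the requires_grad route, assuming a standalone scalar tensor. The input x and the squared-error loss are hypothetical, only there to produce a gradient. Because a bare tensor is not owned by any module, it has to be handed to the optimizer explicitly:

```python
import torch

# A plain tensor becomes trainable once autograd tracks gradients for it.
T = torch.tensor(1.0, requires_grad=True)

# Hypothetical differentiable use of T, only to produce a gradient.
x = torch.tensor([0.5, 1.5, 2.0])
loss = ((x - T) ** 2).mean()
loss.backward()
print(T.grad)  # gradient of the loss with respect to T

# A bare tensor is not registered anywhere, so pass it to the optimizer yourself.
optimizer = torch.optim.SGD([T], lr=0.1)
optimizer.step()  # updates T in place using T.grad
print(T)
```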

Hmm, I don't think that is enough…
I think I need to do something like this:
LearnableParameter = nn.Parameter(Parameter, requires_grad=True)
but I'm not 100% sure.
It would be helpful if someone could confirm it.
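A minimal sketch of the nn.Parameter idea from this post, assuming T lives inside your own nn.Module (ThresholdModel is a hypothetical name). Assigning an nn.Parameter as a module attribute registers it, so it shows up in model.parameters() and the optimizer picks it up without extra work:

```python
import torch
import torch.nn as nn

class ThresholdModel(nn.Module):  # hypothetical module name
    def __init__(self):
        super().__init__()
        # nn.Parameter defaults to requires_grad=True, so passing it is optional.
        self.T = nn.Parameter(torch.tensor(1.0))

model = ThresholdModel()

# The attribute assignment above registered T as a parameter of the module.
names = [name for name, _ in model.named_parameters()]
print(names)

# T is included in model.parameters(), so the optimizer will update it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```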

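If you use nn.Parameter, it will be learnable: it requires gradients by default, and when it is assigned as an attribute of an nn.Module it is registered in model.parameters(), so the optimizer will update it.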
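One caveat worth checking for the threshold use case in the question: the comparison TensorB > T*Means returns a boolean tensor, which carries no gradient, so even a registered nn.Parameter T receives no gradient through that torch.where expression. The sketch below shows this and then swaps in a sigmoid "soft threshold", which is a common workaround and an assumption here, not something proposed in the thread. tensor_b, means and temperature are hypothetical stand-ins:

```python
import torch
import torch.nn as nn

T = nn.Parameter(torch.tensor(1.0))
tensor_b = torch.tensor([0.2, 0.9, 1.4, 2.5])  # hypothetical stand-in for TensorB
means = torch.tensor(1.0)                       # hypothetical stand-in for Means

# Hard threshold as in the question: the bool mask cuts the autograd graph,
# so the output does not depend on T as far as gradients are concerned.
hard = torch.where(tensor_b > T * means,
                   torch.ones_like(tensor_b), torch.zeros_like(tensor_b))
print(hard.requires_grad)

# Soft threshold (swapped-in technique): a steep sigmoid approximates the step
# but is differentiable with respect to T. Smaller temperature = sharper step.
temperature = 0.1  # hypothetical hyperparameter
soft = torch.sigmoid((tensor_b - T * means) / temperature)
soft.sum().backward()
print(T.grad)  # a real gradient now flows into T
```

The temperature trades off how closely the soft mask matches the hard 0/1 output against how useful the gradient is; it would need tuning for a real network.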
