How to add trainable parameters outside of nn.Module objects?

I have reviewed the forums and docs about adding a trainable parameter to a model: define trainable_attr = nn.Parameter(None, requires_grad=True) and then register it on the target model with model.register_parameter(name, trainable_attr). But how can I train a parameter outside of the model? For instance, I have a custom loss that is a weighted sum of two losses, e.g., MSE and cross-entropy:
total_loss = mse_loss * gamma + cross_entropy_loss * (1 - gamma)
How can I define and train gamma?
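As a point of comparison, the register_parameter route from the question can already carry gamma: once registered on the model, it is returned by model.parameters(), so a single optimizer updates it together with the weights. A minimal sketch, where the nn.Linear model, the tensor shapes, and the target names are hypothetical placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 3)  # hypothetical stand-in for your network

# Registering gamma on the model makes model.parameters() yield it.
model.register_parameter("gamma", nn.Parameter(torch.tensor(0.5)))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Hypothetical toy batch: regression targets and class labels.
x = torch.randn(8, 4)
mse_target = torch.randn(8, 3)
ce_target = torch.randint(0, 3, (8,))

out = model(x)
gamma = model.gamma
total_loss = (F.mse_loss(out, mse_target) * gamma
              + F.cross_entropy(out, ce_target) * (1 - gamma))

optimizer.zero_grad()
total_loss.backward()   # gamma receives a gradient like any other weight
optimizer.step()
```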


You can create a separate nn.Module that contains your loss functions and your learnable parameter, like this:

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()  # must run before assigning submodules or parameters
        self.mse_fn = nn.MSELoss()
        self.ce_fn = nn.CrossEntropyLoss()
        # learnable weight between the two losses, initialised to 0.5
        self.gamma = nn.Parameter(torch.tensor([.5]))

    def forward(self, out, mse_target, ce_target):
        mse_loss = self.mse_fn(out, mse_target)
        ce_loss = self.ce_fn(out, ce_target)
        loss = mse_loss * self.gamma + ce_loss * (1 - self.gamma)
        return loss

For the optimization, you can either use a separate optimizer for the loss function's parameters or chain them with the model parameters:

import itertools

loss_fn = CustomLoss()
# one optimizer over both the model's parameters and gamma
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), loss_fn.parameters()), lr=1e-3)
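A self-contained sketch of the separate-optimizer option, assuming a hypothetical nn.Linear stand-in model and random data. It repeats the loss module (including the super().__init__() call that nn.Module requires before submodules or parameters are assigned) so that it runs on its own:

```python
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse_fn = nn.MSELoss()
        self.ce_fn = nn.CrossEntropyLoss()
        self.gamma = nn.Parameter(torch.tensor([0.5]))

    def forward(self, out, mse_target, ce_target):
        mse_loss = self.mse_fn(out, mse_target)
        ce_loss = self.ce_fn(out, ce_target)
        return mse_loss * self.gamma + ce_loss * (1 - self.gamma)

torch.manual_seed(0)
model = nn.Linear(4, 3)  # hypothetical stand-in for your network
loss_fn = CustomLoss()

# One optimizer per module instead of chaining the parameter iterators.
model_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_opt = torch.optim.Adam(loss_fn.parameters(), lr=1e-3)

# Hypothetical toy batch.
x = torch.randn(8, 4)
mse_target = torch.randn(8, 3)
ce_target = torch.randint(0, 3, (8,))

for step in range(5):
    model_opt.zero_grad()
    loss_opt.zero_grad()
    loss = loss_fn(model(x), mse_target, ce_target)  # shape [1], since gamma is [1]
    loss.backward()
    model_opt.step()
    loss_opt.step()
```

Two optimizers are mainly useful if gamma should get its own learning rate or optimizer type. For Adam, which keeps per-parameter state, the chained version gives the same updates when the hyperparameters match.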

I think you could also declare gamma as a plain nn.Parameter outside of any Module, but I don't know how you would "chain" it with the other parameters as I did above.
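One way to include a standalone nn.Parameter without chaining iterators is to pass the optimizer a list of parameter groups, with gamma wrapped in a list. A sketch, assuming a hypothetical nn.Linear model; the separate learning rate for gamma is only there to show that a group can override the defaults:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)                   # hypothetical stand-in model
gamma = nn.Parameter(torch.tensor(0.5))   # lives outside any nn.Module

# Each dict is one parameter group; keys in a group override the defaults.
optimizer = torch.optim.Adam(
    [
        {"params": model.parameters()},
        {"params": [gamma], "lr": 1e-2},
    ],
    lr=1e-3,
)
```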


Thank you. It runs without any error, but the model cannot be trained, i.e., the loss doesn't decrease. In your code snippet, how do you involve the input x? I have the same question about your second solution: how do we connect the input to our trainable gamma there?


Sorry, I made a mistake, but it's now corrected: x should have been out. To be honest, I don't even know whether this loss function would work, because nothing constrains gamma to the range [0, 1], so it could just shoot toward +/- inf.
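If gamma should stay within (0, 1), one common trick (not something proposed in this thread) is to learn an unconstrained raw value and map it through a sigmoid. A minimal sketch with the hypothetical names BoundedWeightLoss and raw_gamma:

```python
import torch
import torch.nn as nn

class BoundedWeightLoss(nn.Module):
    """Hypothetical variant of the loss above: gamma = sigmoid(raw_gamma)."""

    def __init__(self):
        super().__init__()
        self.mse_fn = nn.MSELoss()
        self.ce_fn = nn.CrossEntropyLoss()
        # raw_gamma = 0 gives sigmoid(0) = 0.5 as the starting weight
        self.raw_gamma = nn.Parameter(torch.zeros(1))

    def forward(self, out, mse_target, ce_target):
        gamma = torch.sigmoid(self.raw_gamma)  # always inside (0, 1)
        return (self.mse_fn(out, mse_target) * gamma
                + self.ce_fn(out, ce_target) * (1 - gamma))

loss_fn = BoundedWeightLoss()
```

Note that this only bounds gamma. The gradient of the weighted loss with respect to gamma is mse_loss - ce_loss, so gradient descent still pushes gamma toward whichever loss is currently smaller, and the sigmoid can saturate near 0 or 1.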

As for the second solution, I'm not sure how to implement it correctly myself. You would write something like:

gamma = nn.Parameter(torch.tensor([.5]))
# raises ValueError: can't optimize a non-leaf Tensor
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), gamma), lr=1e-3)

But if you run this, you get an error, because iterating over gamma yields non-leaf tensors; see:

for g in gamma:
    print(g.is_leaf)  # False: each g is produced by an autograd op on gamma

Hopefully someone can point out my error :smiley:
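The error comes from itertools.chain iterating over gamma itself: iterating a tensor splits it along its first dimension, and each piece is produced by an autograd op, so it is not a leaf. Wrapping gamma in a list makes chain yield the Parameter object instead. A sketch with a hypothetical nn.Linear model:

```python
import itertools

import torch
import torch.nn as nn

model = nn.Linear(4, 3)                     # hypothetical stand-in model
gamma = nn.Parameter(torch.tensor([0.5]))

# The pieces obtained by iterating gamma carry a grad_fn, so none is a leaf.
print(all(not g.is_leaf for g in gamma))    # True

# Wrapping gamma in a list makes chain() yield the Parameter itself.
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), [gamma]), lr=1e-3)
```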