How to apply a gradient to a tunable gamma parameter in the loss function

Take the following pseudocode, which attempts to define a global tunable parameter 'gamma' that is incorporated into the loss function. The rough solution has been edited in below.

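import torch
import torch.nn as nn

class someModel(nn.Module):
    def __init__(self, process):
        super().__init__()
        # Use the module passed in (the original ignored this argument)
        self.process = process
        # Registered as a parameter, so model.parameters() includes gamma
        self.gamma = nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(self, x):
        xhat = self.process(x)
        # Return gamma alongside the output so the loss can use it
        return xhat, self.gamma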

# set up 'loader'
for i, (x, _) in enumerate(loader):
    x =
    xhat, gamma = model(x)
    # gamma stays in the autograd graph, so loss.backward() fills gamma.grad
    loss = loss_function(xhat, x, gamma)
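
For anyone landing here later, here is a minimal runnable sketch of the idea above. It is not the exact solution from the thread, and several pieces are assumptions made purely for illustration: the `nn.Linear` stand-in for the processing module, the random batch in place of the loader, and especially the body of `loss_function` (a squared error scaled by gamma plus a log term), which the original post never specifies. The point it tries to show is that because gamma is an `nn.Parameter` attribute of the module, `model.parameters()` already contains it, so the optimizer updates it like any other weight.

```python
import torch
import torch.nn as nn


class SomeModel(nn.Module):
    def __init__(self, process):
        super().__init__()
        self.process = process
        # nn.Parameter registers gamma with the module, so it shows up
        # in model.parameters() and receives gradients.
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.process(x), self.gamma


def loss_function(xhat, x, gamma):
    # Hypothetical loss (an assumption, not from the original post):
    # squared error scaled by gamma, plus a log term that keeps gamma
    # from growing without bound.
    mse = ((xhat - x) ** 2).mean()
    return (mse / (2 * gamma ** 2) + 0.5 * torch.log(gamma ** 2)).sum()


torch.manual_seed(0)
model = SomeModel(nn.Linear(4, 4))  # stand-in for the real processing module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)  # stand-in for one batch from the loader
xhat, gamma = model(x)
loss = loss_function(xhat, x, gamma)

optimizer.zero_grad()
loss.backward()   # populates model.gamma.grad
optimizer.step()  # updates gamma together with the other weights
```

If gamma were created as a plain tensor outside the module instead, it would need to be passed to the optimizer explicitly, for example in its own parameter group.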

You can find a PyTorch implementation of such a loss in this thread.

Never mind, I solved this. I added the rough solution to the question for posterity.