Consider the following pseudocode, which attempts to define a global tunable parameter `gamma` that is incorporated into the loss function. A rough solution is given below.
class someModel(nn.Module):
    """Wrap a processing sub-network and own one learnable scalar ``gamma``.

    ``gamma`` is registered as an ``nn.Parameter`` attribute of the module, so
    it is included in ``self.parameters()`` and is updated by any optimizer
    constructed from them. It is also returned from ``forward`` so that the
    training loop can feed it into the loss function.
    """

    def __init__(self, process=None, init_gamma=1.0):
        """Build the model.

        Args:
            process: module applied to the input in ``forward``. If ``None``,
                falls back to ``SomeNnFunction()`` (the previously hard-coded
                choice).
            init_gamma: initial value of ``gamma``. Defaults to 1.0, which
                matches the original ``torch.ones(1)`` initialisation.
        """
        super().__init__()
        # The ``process`` argument used to be ignored in favour of a fresh
        # SomeNnFunction(); honour what the caller passed in.
        self.process = process if process is not None else SomeNnFunction()
        # Shape-(1,) learnable tensor; nn.Parameter sets requires_grad=True
        # by default.
        # NOTE(review): gamma is unconstrained. If loss_function (not visible
        # here) decreases monotonically in gamma, the optimizer will push it
        # toward a degenerate value. Consider a reparameterisation (e.g.
        # softplus) or a regularising term -- confirm against the loss.
        self.gamma = nn.Parameter(torch.full((1,), float(init_gamma)))

    def forward(self, x):
        """Apply ``process`` to ``x`` and return ``(xhat, gamma)``.

        ``gamma`` is returned rather than read by the caller as
        ``model.gamma``, so that it stays reachable when the model is wrapped.
        For example, ``nn.DataParallel`` gathers each replica's shape-(1,)
        output into one tensor with an entry per replica -- presumably why
        the caller reduces it with ``.mean()``.
        """
        xhat = self.process(x)
        return xhat, self.gamma
# Set up `loader` here (an iterable of `(x, label)` batches, as unpacked below).
# One training pass over `loader`. `model`, `optimizer`, `device` and
# `loss_function` must already be defined. The optimizer must include
# `model.gamma` among its parameters (e.g. be built from model.parameters()),
# otherwise gamma receives gradients but is never updated.
for x, _ in loader:  # batch index was unused; labels are ignored
    model.zero_grad()
    x = x.to(device)
    xhat, gamma = model(x)
    # gamma has shape (1,) from a plain model. It presumably has one entry per
    # replica if the model is wrapped in nn.DataParallel. .mean() reduces
    # either case to a 0-dim scalar for the loss.
    gamma = gamma.mean()
    loss = loss_function(xhat, x, gamma)
    loss.backward()
    optimizer.step()