This is what I do:
class YourModel(nn.Module):
    """Template model that exposes its own regularization term.

    The penalty returned by ``regularization()`` is meant to be scaled by a
    coefficient and added to the data loss in the training loop.
    """

    def __init__(self, x, y):
        """Store the two components whose weights are regularized.

        Args:
            x: Object exposing a ``weight`` tensor. Presumably a layer such
                as ``nn.Embedding`` or ``nn.Linear`` -- confirm for your
                model.
            y: Second object exposing a ``weight`` tensor (same assumption).
        """
        # The class passed to super() must be this class; naming any other
        # (undefined) class here raises NameError at construction time.
        super(YourModel, self).__init__()
        # regularization() reads self.x.weight / self.y.weight, so both
        # arguments have to be kept on the instance.
        self.x = x
        self.y = y

    def init_weights(self):
        """Placeholder: put custom parameter initialization here."""
        pass

    def regularization(self):
        """Return the squared L2 norm of both weight tensors as a scalar.

        Replace the expression with whatever penalty your model needs; this
        is the example term (sum of squared entries of each weight).
        """
        x_penalty = torch.sum(torch.pow(self.x.weight, 2))
        y_penalty = torch.sum(torch.pow(self.y.weight, 2))
        return x_penalty + y_penalty

    def forward(self):
        """Placeholder: implement the model's forward computation here."""
        pass
And for the loss:
# reg_lambda is the regularization strength. It cannot be called `lambda`,
# because that is a reserved keyword in Python and would be a SyntaxError.
loss = criterion(pred, target) + reg_lambda * model.regularization()