Hello, everyone
I want to make a custom regularization layer with PyTorch, but something is wrong with my regularization layer: the loss output is the same at every training iteration.
The real problem is that myloss gets the same net.parameters() in every training step, but I do not know why it sees the same weight parameters even though I pass in the updated network.
My custom layer is like below
class bipolar_loss(nn.Module):
    def __init__(self, lambd=5e-7):
        super(bipolar_loss, self).__init__()
        self.lambd = lambd

    def forward(self, net):
        loss = 0
        for param in net.parameters():
            # only for (4-D convolution) weight parameters
            if len(param.size()) == 4:
                loss += (1 - torch.pow(param, 2)).sum()
        return loss * self.lambd
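For reference, here is a self-contained sanity check of this regularizer on a single conv layer (the layer shape and the perturbation are made up purely for illustration): the loss value does change when the weights change.

```python
import torch
import torch.nn as nn

class bipolar_loss(nn.Module):
    # same regularizer as above
    def __init__(self, lambd=5e-7):
        super(bipolar_loss, self).__init__()
        self.lambd = lambd

    def forward(self, net):
        loss = 0
        for param in net.parameters():
            # only 4-D (convolution) weight parameters
            if len(param.size()) == 4:
                loss += (1 - torch.pow(param, 2)).sum()
        return loss * self.lambd

# arbitrary conv layer, chosen only for this check
net = nn.Conv2d(3, 8, kernel_size=3)
reg = bipolar_loss()

before = reg(net).item()
with torch.no_grad():
    for p in net.parameters():
        p.add_(0.1)  # shift weights, as an optimizer step would
after = reg(net).item()

print(before != after)  # True: the loss tracks the current weights
```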
And I use this layer in main.py like below
criterion = nn.CrossEntropyLoss().cuda()
loss_func = bipolar_loss().cuda()
In the training process, what I want to do is add the two losses together, like L2 regularization, so I sum them in every training iteration:
for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    # forward pass to get output / logits
    outputs = net(inputs)
    # calculate loss (default is cross-entropy loss)
    loss = criterion(outputs, targets)
    myloss = loss_func(net)
    loss = loss + myloss
    # compute gradients
    loss.backward()
    # update parameters
    optimizer.step()
    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

epoch_loss = train_loss/(batch_idx+1)
epoch_acc = 100.*correct/total
However, I get the same value when printing the output of myloss:
iteration 0:
loss: tensor(2.2010, device='cuda:6', grad_fn=<NllLossBackward>)
myloss: tensor(2.3100, device='cuda:6', grad_fn=<MulBackward0>)
total_loss: tensor(4.5111, device='cuda:6', grad_fn=<AddBackward0>)
iteration 1:
loss: tensor(2.2096, device='cuda:6', grad_fn=<NllLossBackward>)
myloss: tensor(2.3100, device='cuda:6', grad_fn=<MulBackward0>)
total_loss: tensor(4.5196, device='cuda:6', grad_fn=<AddBackward0>)
myloss is the same in every iteration and epoch…
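To narrow this down, I think a minimal synthetic loop can check whether the conv weights seen between optimizer steps actually move (all names, shapes, and data below are made up for the check, not from my real training script):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# tiny stand-in network and synthetic batch
net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.Flatten(), nn.Linear(4 * 6 * 6, 2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(8, 1, 8, 8)
targets = torch.randint(0, 2, (8,))

sums = []
for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()
    # checksum of the conv weights after each step;
    # if the optimizer is updating them, these should all differ
    sums.append(net[0].weight.pow(2).sum().item())

print(sums)
```

If the printed checksums differ every iteration here but myloss still stays constant in the real script, the problem would have to be in how the regularizer sees the network rather than in the optimizer step.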