I have added a regularization term that adds the spectral norm of the weights to the loss. I am using PyTorch 0.4.
On the ImageNet dataset, the loss grows exponentially and then turns into `nan`, as shown below:
Processing | | (1/5005) Data: 3.821s | Batch: 29.156s | Total: 0:00:29 | ETA: 0:00:00 | Loss: 25.6833 | top1: 0.0000
Processing | | (2/5005) Data: 0.001s | Batch: 1.052s | Total: 0:00:30 | ETA: 1 day, 16:31:40 | Loss: 29.6706 | top1: 0.0000
Processing | | (3/5005) Data: 0.010s | Batch: 0.652s | Total: 0:00:30 | ETA: 20:59:50 | Loss: 29.0471 | top1: 0.0000
Processing | | (4/5005) Data: 0.011s | Batch: 0.645s | Total: 0:00:31 | ETA: 14:17:51 | Loss: 58.7433 | top1: 0.0977
Processing | | (5/5005) Data: 0.010s | Batch: 0.656s | Total: 0:00:32 | ETA: 10:56:43 | Loss: 53.8333 | top1: 0.1562
Processing | | (6/5005) Data: 0.010s | Batch: 0.633s | Total: 0:00:32 | ETA: 8:56:12 | Loss: 53.3468 | top1: 0.1302 | top5:
Processing | | (7/5005) Data: 0.011s | Batch: 0.645s | Total: 0:00:33 | ETA: 7:35:32 | Loss: 51.0335 | top1: 0.1674 | top5:
Processing | | (8/5005) Data: 0.010s | Batch: 0.632s | Total: 0:00:34 | ETA: 6:38:03 | Loss: 208.4966 | top1: 0.1465
Processing | | (9/5005) Data: 0.010s | Batch: 0.931s | Total: 0:00:35 | ETA: 5:54:48 | Loss: 270749.7812 | top1: 0.2170
Processing | | (10/5005) Data: 0.010s | Batch: 0.767s | Total: 0:00:35 | ETA: 5:23:56 | Loss: 46110831804416.0000 | top1:
Processing | | (11/5005) Data: 0.011s | Batch: 0.776s | Total: 0:00:36 | ETA: 4:57:52 | Loss: 42443835703296.0000 | top1:
The same regularization works perfectly fine for CIFAR-10/100, but it does not work for the ImageNet dataset.
Essentially, this is what the regularization function does:
import torch
from torch.nn.functional import normalize

def l2_reg_ortho(mdl):
    l2_reg = None
    for w_tmp in mdl.parameters():
        if w_tmp.ndimension() < 2:
            continue
        height = w_tmp.size(0)
        # flatten conv kernels to (out_channels, -1) so the power
        # iteration below operates on a 2-D matrix
        w_tmp = w_tmp.view(height, -1)
        # one step of power iteration to estimate the largest
        # singular value (spectral norm) of w_tmp
        u = normalize(w_tmp.new_empty(height).normal_(0, 1), dim=0, eps=1e-12)
        v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
        u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
        sigma = torch.dot(u, torch.matmul(w_tmp, v))
        if l2_reg is None:
            l2_reg = (torch.norm(sigma, 2)) ** 2
        else:
            l2_reg = l2_reg + (torch.norm(sigma, 2)) ** 2
    return l2_reg
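As a sanity check, the single power-iteration step can be verified against an exact SVD on a toy matrix. This is a standalone sketch on a recent PyTorch (`spectral_norm_estimate` is just a name I use here); one iteration from a random start gives a lower bound on the true spectral norm, not an exact value:

```python
import torch
from torch.nn.functional import normalize

def spectral_norm_estimate(w, eps=1e-12):
    # one step of power iteration: returns an estimate of the
    # largest singular value of w (a lower bound for a single step)
    if w.ndimension() > 2:
        w = w.view(w.size(0), -1)  # flatten conv kernels
    u = normalize(w.new_empty(w.size(0)).normal_(0, 1), dim=0, eps=eps)
    v = normalize(torch.mv(w.t(), u), dim=0, eps=eps)
    u = normalize(torch.mv(w, v), dim=0, eps=eps)
    return torch.dot(u, torch.mv(w, v))

torch.manual_seed(0)
W = torch.randn(8, 16)
est = spectral_norm_estimate(W)
exact = torch.linalg.matrix_norm(W, 2)  # exact largest singular value
print(est.item(), exact.item())
```

The estimate should never exceed the exact value; a large gap would just mean more power-iteration steps are needed.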
Adding it to the loss:

# compute output
outputs = model(inputs)

# spectral-norm regularization term, scaled by the decay coefficient
oloss = l2_reg_ortho(model)
oloss = odecay * oloss

loss = criterion(outputs, targets)
loss = loss + oloss
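To narrow down which layer is blowing up, I could log the exact per-layer spectral norms every few iterations. This is only a debugging sketch (it runs a full SVD per layer via `torch.linalg.matrix_norm`, which needs a recent PyTorch and is slow, so it is for diagnostics only):

```python
import torch
import torch.nn as nn

def log_layer_sigmas(mdl):
    # print the exact largest singular value of every >=2-D parameter;
    # a value that keeps growing points at the offending layer
    for name, w in mdl.named_parameters():
        if w.ndimension() < 2:
            continue
        w2d = w.detach().view(w.size(0), -1)
        sigma = torch.linalg.matrix_norm(w2d, 2)
        print(f"{name}: sigma = {sigma.item():.4f}")

# example on a tiny stand-in model
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
log_layer_sigmas(model)
```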
Please suggest what I might be missing here.
Thank You.