Blown-up gradients and loss

(Srikanth) #1

I have posted my entire code and its output below, and I am surprised to see such huge gradients and losses. The code trains on CIFAR-10 using the local reparameterization trick of Shayer et al., but I am unable to reproduce the results in the paper. I have been working on this for over a month and have no clue where I am going wrong. Could someone please help me out?
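
For context, here is the ternary weight parameterization I am trying to learn, as I understand it from the paper (a minimal sketch with an illustrative helper name, not part of the training code below):

import torch

def ternary_weight_stats(a, b):
  # per my reading of Shayer et al.: p(w=0) = sigmoid(a), p(w=1 | w != 0) = sigmoid(b)
  p_zero = torch.sigmoid(a)
  p_pos  = torch.sigmoid(b) * (1 - p_zero)
  p_neg  = 1 - p_zero - p_pos
  w_mean = p_pos - p_neg                # E[w]  for w in {-1, 0, 1}
  w_var  = (1 - p_zero) - w_mean ** 2   # Var[w], since E[w^2] = 1 - p(w=0)
  return w_mean, w_var

Each layer output is then sampled from a Gaussian with the corresponding mean and variance (the local reparameterization trick), which is what reparam and reparamfc below are meant to do.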

# (imports; device, the loaded full-precision weights in 'model', and the
#  CIFAR-10 train_loader/valid_loader are set up earlier in my script)
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

class Ternary_batch_rel(torch.nn.Module):
  def __init__(self,batchnorm_size):
    super(Ternary_batch_rel,self).__init__()
    self.l1=torch.nn.Sequential(
      torch.nn.ReLU(),
      torch.nn.BatchNorm2d(batchnorm_size)
    )
  
  def forward(self,x):
    out=self.l1(x)
    return out
    
z1=Ternary_batch_rel(128).to(device)
z2=Ternary_batch_rel(256).to(device)
z3=Ternary_batch_rel(512).to(device)

class Ternary_max_pool(torch.nn.Module):
  def __init__(self):
    
    super(Ternary_max_pool,self).__init__()
    self.l1=torch.nn.Sequential(
    torch.nn.MaxPool2d(kernel_size=2,stride=2))
      
  def forward(self,x):
    out=self.l1(x)
    return out

zm=Ternary_max_pool().to(device)
sigm=torch.nn.Sigmoid().to(device)

# weights of the pretrained full-precision network, one getter per conv layer
def convlayer1param():
  return model['layer1.0.weight']   # shape [128, 3, 3, 3]

def convlayer2param():
  return model['layer1.3.weight']   # shape [128, 128, 3, 3]

def convlayer3param():
  return model['layer2.0.weight']   # shape [256, 128, 3, 3]

def convlayer4param():
  return model['layer2.3.weight']   # shape [256, 256, 3, 3]

def convlayer5param():
  return model['layer3.0.weight']   # shape [512, 256, 3, 3]

def convlayer6param():
  return model['layer3.3.weight']   # shape [512, 512, 3, 3]
def sampling(a,b):
  # draw hard ternary weights in {-1,0,1} for evaluation:
  # p(w=0)=sigmoid(a), p(w=1 | w!=0)=sigmoid(b)
  s=a.shape
  sigm=torch.nn.Sigmoid()
  p2=sigm(a)               # p(w=0)
  p1=sigm(b)*(1-sigm(a))   # p(w=1)
  p3=1-(p2+p1)             # p(w=-1)
  p1=p1.reshape(-1)
  p2=p2.reshape(-1)
  p3=p3.reshape(-1)
  p=torch.stack([p1,p2,p3]).t()
  wb_c=torch.tensor(1)-Categorical(p).sample()   # categories 0,1,2 -> weights 1,0,-1
  wb_c=wb_c.reshape(s)
  return wb_c.type(torch.float32)


def sampling_testing(al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2):
  
  correct=0
  total=0
  wb_c11=sampling(al1,bl1).cuda()
  wb_c12=sampling(al2,bl2).cuda()
  wb_c21=sampling(al3,bl3).cuda()
  wb_c22=sampling(al4,bl4).cuda()
  wb_c31=sampling(al5,bl5).cuda()
  wb_c32=sampling(al6,bl6).cuda()
  wb_fc1=sampling(a1,b1).cuda()
  wb_fc2=sampling(a2,b2).cuda()
    
  for images,labels in valid_loader:
    images=images.to(device)
    labels=labels.to(device)
    yc1=F.conv2d(images,wb_c11,padding=1)
    yc1=z1(yc1)
    yc2=F.conv2d(yc1,wb_c12,padding=1)
    yc2=z1(yc2)
    yc2=zm(yc2)
            
    yc3=F.conv2d(yc2,wb_c21,padding=1)
    yc3=z2(yc3)
    yc4=F.conv2d(yc3,wb_c22,padding=1)
    yc4=z2(yc4)
    yc4=zm(yc4)
            
    yc5=F.conv2d(yc4,wb_c31,padding=1)
    yc5=z3(yc5)
    yc6=F.conv2d(yc5,wb_c32,padding=1)
    yc6=z3(yc6)
    yc6=zm(yc6)
    yc6=yc6.reshape(yc6.size(0),-1)
    y1=F.linear(yc6,wb_fc1)
    y1=F.relu(y1)
    y1=F.dropout(y1)
    y2=F.linear(y1,wb_fc2)
    _,predicted=torch.max(y2,1)
    total+=labels.size(0)
    correct+=(predicted==labels).sum().item()
  print('Test accuracy of the model on the 10000 test images:{}%'.format((correct/total)*100))
  
def reparam(a,b,h):
  # mean/variance of each ternary weight, then of the conv pre-activation
  weight_m=(2*sigm(b)-(2*sigm(a)*sigm(b))-1+sigm(a))   # E[w]
  weight_v=(1-sigm(a))-weight_m**2                     # Var[w]
  assert torch.all(weight_v>=0)
  om=F.conv2d(h,weight_m,padding=1)      # mean of the pre-activation
  ov=F.conv2d(h**2,weight_v,padding=1)   # variance of the pre-activation
  assert torch.all(ov>=0)
  e=torch.randn_like(ov,device=device)
  z=om+(ov*e)
  return z

def reparamcnn1(a,b,h):
 
  z=reparam(a,b,h)
  return z1(z)

def reparamcnn2(a,b,h):
  z=reparam(a,b,h)
  op=z1(z)
  return zm(op)

def reparamcnn3(a,b,h):
  z=reparam(a,b,h)
  return z2(z)
  
def reparamcnn4(a,b,h):
  z=reparam(a,b,h)
  op=z2(z)
  return zm(op)

def reparamcnn5(a,b,h):
  z=reparam(a,b,h)
  return z3(z)

def reparamcnn6(a,b,h):
  z=reparam(a,b,h)
  op=z3(z)
  return zm(op)

def reparamfc(a,b,h):
    weight_m=(2*sigm(b)-(2*sigm(a)*sigm(b))-1+sigm(a))
    weight_v=(1-sigm(a))-weight_m**2
    assert torch.all(weight_v>=0)
    om=torch.matmul(weight_m,h)
    ov=torch.matmul(weight_v,h**2)
    assert torch.all(ov>=0)
    e=torch.randn_like(ov,device=device)
    z=om+(ov*e)
    return z

def initialize(wfp):
  wtilde=wfp/torch.std(wfp)
  sigma_a=0.95-((0.95-0.05)*torch.abs(wtilde))
  sigma_b=0.5*(1+(wfp/(1-sigma_a)))
  sigma_a=torch.clamp(sigma_a,0.05,0.95)
  sigma_b=torch.clamp(sigma_b,0.05,0.95)
  a=torch.log(sigma_a/(1-sigma_a)).requires_grad_().cuda()
  b=torch.log(sigma_b/(1-sigma_b)).requires_grad_().cuda()
  
  return a,b

w1fpconv=convlayer1param()
w2fpconv=convlayer2param()
w3fpconv=convlayer3param()
w4fpconv=convlayer4param()
w5fpconv=convlayer5param()
w6fpconv=convlayer6param()
wfp1=model['layer4.0.weight']
wfp2=model['layer4.3.weight']
al1,bl1=initialize(w1fpconv)
al2,bl2=initialize(w2fpconv)
al3,bl3=initialize(w3fpconv)
al4,bl4=initialize(w4fpconv)
al5,bl5=initialize(w5fpconv)
al6,bl6=initialize(w6fpconv)
a1,b1=initialize(wfp1)
a2,b2=initialize(wfp2)
al1=torch.nn.Parameter(al1)
bl1=torch.nn.Parameter(bl1)
al2=torch.nn.Parameter(al2)
bl2=torch.nn.Parameter(bl2)
al3=torch.nn.Parameter(al3)
bl3=torch.nn.Parameter(bl3)
al4=torch.nn.Parameter(al4)
bl4=torch.nn.Parameter(bl4)
al5=torch.nn.Parameter(al5)
bl5=torch.nn.Parameter(bl5)
al6=torch.nn.Parameter(al6)
bl6=torch.nn.Parameter(bl6)
a1=torch.nn.Parameter(a1)
a2=torch.nn.Parameter(a2)
b1=torch.nn.Parameter(b1)
b2=torch.nn.Parameter(b2)
betaparam=1e-11
lossfunc=torch.nn.CrossEntropyLoss()
model_params=[al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2]
optimizer=torch.optim.Adam(model_params,lr=0.01,weight_decay=1e-4)

num_epochs=300
for epoch in range(num_epochs):
  for i,(images,labels) in enumerate(train_loader):
    #print(i)
    images=images.to(device)
    labels=labels.to(device)
    y1=reparamcnn1(al1,bl1,images)
    y2=reparamcnn2(al2,bl2,y1)
    y3=reparamcnn3(al3,bl3,y2)
    y4=reparamcnn4(al4,bl4,y3)
    y5=reparamcnn5(al5,bl5,y4)
    y6=reparamcnn6(al6,bl6,y5)
    y6=y6.reshape(y6.size(0),-1)
    y6=torch.t(y6)
    y8=reparamfc(a1,b1,y6)
    y9=F.relu(y8)
    y10=F.dropout(y9)
    yout=reparamfc(a2,b2,y10)
    yout=torch.t(yout)
    #l2=al1.norm(2)+bl1.norm(2)+al2.norm(2)+bl2.norm(2)+al3.norm(2)+bl3.norm(2)+al4.norm(2)+bl4.norm(2)+al5.norm(2)+bl5.norm(2)+al6.norm(2)+bl6.norm(2)+a1.norm(2)+b1.norm(2)+a2.norm(2)+b2.norm(2)
    lossi=lossfunc(yout,labels)#+(betaparam*l2)
    if(epoch==170):
      lr=0.001
      for param_group in optimizer.param_groups:
        param_group['lr']=lr
    optimizer.zero_grad()
    lossi.backward()
    optimizer.step()
    
  print('epoch {}'.format(epoch),'loss = {}'.format(lossi.item()))
  print(a1.grad.sum())
  if(epoch%5==0):
    sampling_testing(al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2)

The results are here:
epoch 0 loss = 50530525184.0
tensor(-3.1266e+10, device='cuda:0')
Test accuracy of the model on the 10000 test images:10.15%
epoch 1 loss = 811278598144.0
tensor(-4.7532e+11, device='cuda:0')
epoch 2 loss = 86112829440.0
tensor(-5.2734e+10, device='cuda:0')
epoch 3 loss = 158551998464.0
tensor(-1.0209e+11, device='cuda:0')
epoch 4 loss = 69080948736.0
tensor(-4.3354e+10, device='cuda:0')
epoch 5 loss = 244970094592.0
tensor(-1.4066e+11, device='cuda:0')
Test accuracy of the model on the 10000 test images:11.129999999999999%
epoch 6 loss = 61311873024.0
tensor(-3.6518e+10, device='cuda:0')
epoch 7 loss = 100857921536.0
tensor(-6.2719e+10, device='cuda:0')
epoch 8 loss = 21799936000.0
tensor(-1.3609e+10, device='cuda:0')
epoch 9 loss = 35716399104.0
tensor(-2.2920e+10, device='cuda:0')
epoch 10 loss = 4038941440.0
tensor(-2.5166e+09, device='cuda:0')
Test accuracy of the model on the 10000 test images:11.59%

and so on…

(Deeply) #2

I am not sure about your code, but the first thing I would do is move optimizer.zero_grad() so it runs right after the data are moved to the device, at the start of each iteration (see the sketch below).

Although it might have nothing to do with your problem, also try using an lr_scheduler to decay the learning rate, something like:

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_milestones, gamma=lr_gamma)

Then replace the if (epoch==170) block with a single scheduler.step() call at the end of each epoch.
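
Roughly like this (just a sketch of the loop structure; the milestones and gamma are placeholders to tune, and the forward pass is elided):

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[170, 250], gamma=0.1)

for epoch in range(num_epochs):
  for images, labels in train_loader:
    images = images.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()   # clear stale gradients right after moving the batch
    # ... forward pass as in your code, producing yout ...
    lossi = lossfunc(yout, labels)
    lossi.backward()
    optimizer.step()
  scheduler.step()          # decays the lr by gamma at each milestone epoch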

(Srikanth) #3

@Deeply Thank you for your reply. I tried both of these things, but the gradients are still too high and the accuracy is very bad.

Is there any other solution you can provide?