I have my entire code written here and the output, I am surprised to see such huge gradients and losses. I am working on this code on CIFAR-10 based on local reparameterization trick by Shayer et.al, I am unable to reproduce the results in the paper. I have been working on this for over a month but no clue where I am going wrong. Could someone help me out Please.
class Ternary_batch_rel(torch.nn.Module):
def __init__(self,batchnorm_size):
super(Ternary_batch_rel,self).__init__()
self.l1=torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.BatchNorm2d(batchnorm_size)
)
def forward(self,x):
out=self.l1(x)
return out
z1=Ternary_batch_rel(128).to(device)
z2=Ternary_batch_rel(256).to(device)
z3=Ternary_batch_rel(512).to(device)
class Ternary_max_pool(torch.nn.Module):
def __init__(self):
super(Ternary_max_pool,self).__init__()
self.l1=torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2,stride=2))
def forward(self,x):
out=self.l1(x)
return out
zm=Ternary_max_pool().to(device)
sigm=torch.nn.Sigmoid().to(device)
def convlayer1param():
s1=torch.zeros([128,3,3,3])
s1=model['layer1.0.weight']
return s1
def convlayer2param():
s1=torch.zeros([128,128,3,3])
s1=model['layer1.3.weight']
return s1
def convlayer3param():
s1=torch.zeros([256,128,3,3])
s1=model['layer2.0.weight']
return s1
def convlayer4param():
s1=torch.zeros([256,256,3,3])
s1=model['layer2.3.weight']
return s1
def convlayer5param():
s1=torch.zeros([512,256,3,3])
s1=model['layer3.0.weight']
return s1
def convlayer6param():
s1=torch.zeros([512,512,3,3])
s1=model['layer3.3.weight']
return s1
def sampling(a,b):
s=a.shape
wb_c=torch.zeros(s)
sigm=torch.nn.Sigmoid()
p2=sigm(a)
p1=sigm(b)*(1-sigm(a))
p3=1-(p2+p1)
p2=p2.reshape(-1)
p1=p1.reshape(-1)
p3=p3.reshape(-1)
prob=[p1,p2,p3]
p=torch.stack(prob)
p=p.t()
wb_c=torch.tensor(1)-Categorical(p).sample()
wb_c=wb_c.reshape(s)
return wb_c.type(torch.float32)
def sampling_testing(al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2):
correct=0
total=0
wb_c11=sampling(al1,bl1).cuda()
wb_c12=sampling(al2,bl2).cuda()
wb_c21=sampling(al3,bl3).cuda()
wb_c22=sampling(al4,bl4).cuda()
wb_c31=sampling(al5,bl5).cuda()
wb_c32=sampling(al6,bl6).cuda()
wb_fc1=sampling(a1,b1).cuda()
wb_fc2=sampling(a2,b2).cuda()
for images,labels in valid_loader:
images=images.to(device)
labels=labels.to(device)
yc1=F.conv2d(images,wb_c11,padding=1)
yc1=z1(yc1)
yc2=F.conv2d(yc1,wb_c12,padding=1)
yc2=z1(yc2)
yc2=zm(yc2)
yc3=F.conv2d(yc2,wb_c21,padding=1)
yc3=z2(yc3)
yc4=F.conv2d(yc3,wb_c22,padding=1)
yc4=z2(yc4)
yc4=zm(yc4)
yc5=F.conv2d(yc4,wb_c31,padding=1)
yc5=z3(yc5)
yc6=F.conv2d(yc5,wb_c32,padding=1)
yc6=z3(yc6)
yc6=zm(yc6)
yc6=yc6.reshape(yc6.size(0),-1)
y1=F.linear(yc6,wb_fc1)
y1=F.relu(y1)
y1=F.dropout(y1)
y2=F.linear(y1,wb_fc2)
_,predicted=torch.max(y2,1)
total+=labels.size(0)
correct+=(predicted==labels).sum().item()
print('Test accuracy of the model on the 10000 test images:{}%'.format((correct/total)*100))
def reparam(a,b,h):
weight_m= (2*sigm(b)-(2*sigm(a)*sigm(b))-1+sigm(a))
weight_v=(1-sigm(a))-weight_m**2
assert torch.all(weight_v>=0)
om=F.conv2d(h,weight_m,padding=1)
ov=F.conv2d(h**2,weight_v,padding=1)
assert torch.all(ov>=0)
e=torch.randn_like(ov,device=device)
z=om+(ov*e)
return z
def reparamcnn1(a,b,h):
z=reparam(a,b,h)
return z1(z)
def reparamcnn2(a,b,h):
z=reparam(a,b,h)
op=z1(z)
return zm(op)
def reparamcnn3(a,b,h):
z=reparam(a,b,h)
return z2(z)
def reparamcnn4(a,b,h):
z=reparam(a,b,h)
op=z2(z)
return zm(op)
def reparamcnn5(a,b,h):
z=reparam(a,b,h)
return z3(z)
def reparamcnn6(a,b,h):
z=reparam(a,b,h)
op=z3(z)
return zm(op)
def reparamfc(a,b,h):
weight_m=(2*sigm(b)-(2*sigm(a)*sigm(b))-1+sigm(a))
weight_v=(1-sigm(a))-weight_m**2
assert torch.all(weight_v>=0)
om=torch.matmul(weight_m,h)
ov=torch.matmul(weight_v,h**2)
assert torch.all(ov>=0)
e=torch.randn_like(ov,device=device)
z=om+(ov*e)
return z
def initialize(wfp):
wtilde=wfp/torch.std(wfp)
sigma_a=0.95-((0.95-0.05)*torch.abs(wtilde))
sigma_b=0.5*(1+(wfp/(1-sigma_a)))
sigma_a=torch.clamp(sigma_a,0.05,0.95)
sigma_b=torch.clamp(sigma_b,0.05,0.95)
a=torch.log(sigma_a/(1-sigma_a)).requires_grad_().cuda()
b=torch.log(sigma_b/(1-sigma_b)).requires_grad_().cuda()
return a,b
w1fpconv=convlayer1param()
w2fpconv=convlayer2param()
w3fpconv=convlayer3param()
w4fpconv=convlayer4param()
w5fpconv=convlayer5param()
w6fpconv=convlayer6param()
wfp1=model['layer4.0.weight']
wfp2=model['layer4.3.weight']
al1,bl1=initialize(w1fpconv)
al2,bl2=initialize(w2fpconv)
al3,bl3=initialize(w3fpconv)
al4,bl4=initialize(w4fpconv)
al5,bl5=initialize(w5fpconv)
al6,bl6=initialize(w6fpconv)
a1,b1=initialize(wfp1)
a2,b2=initialize(wfp2)
al1=torch.nn.Parameter(al1)
bl1=torch.nn.Parameter(bl1)
al2=torch.nn.Parameter(al2)
bl2=torch.nn.Parameter(bl2)
al3=torch.nn.Parameter(al3)
bl3=torch.nn.Parameter(bl3)
al4=torch.nn.Parameter(al4)
bl4=torch.nn.Parameter(bl4)
al5=torch.nn.Parameter(al5)
bl5=torch.nn.Parameter(bl5)
al6=torch.nn.Parameter(al6)
bl6=torch.nn.Parameter(bl6)
a1=torch.nn.Parameter(a1)
a2=torch.nn.Parameter(a2)
b1=torch.nn.Parameter(b1)
b2=torch.nn.Parameter(b2)
betaparam=1e-11
lossfunc=torch.nn.CrossEntropyLoss()
model_params=[al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2]
optimizer=torch.optim.Adam(model_params,lr=0.01,weight_decay=1e-4)
num_epochs=300
for epoch in range(num_epochs):
for i,(images,labels) in enumerate(train_loader):
#print(i)
images=images.to(device)
labels=labels.to(device)
y1=reparamcnn1(al1,bl1,images)
y2=reparamcnn2(al2,bl2,y1)
y3=reparamcnn3(al3,bl3,y2)
y4=reparamcnn4(al4,bl4,y3)
y5=reparamcnn5(al5,bl5,y4)
y6=reparamcnn6(al6,bl6,y5)
y6=y6.reshape(y6.size(0),-1)
y6=torch.t(y6)
y8=reparamfc(a1,b1,y6)
y9=F.relu(y8)
y10=F.dropout(y9)
yout=reparamfc(a2,b2,y10)
yout=torch.t(yout)
#l2=al1.norm(2)+bl1.norm(2)+al2.norm(2)+bl2.norm(2)+al3.norm(2)+bl3.norm(2)+al4.norm(2)+bl4.norm(2)+al5.norm(2)+bl5.norm(2)+al6.norm(2)+bl6.norm(2)+a1.norm(2)+b1.norm(2)+a2.norm(2)+b2.norm(2)
lossi=lossfunc(yout,labels)#+(betaparam*l2)
if(epoch==170):
lr=0.001
for param_group in optimizer.param_groups:
param_group['lr']=lr
optimizer.zero_grad()
lossi.backward()
optimizer.step()
print('epoch {}'.format(epoch),'loss = {}'.format(lossi.item()))
print(a1.grad.sum())
if(epoch%5==0):
sampling_testing(al1,bl1,al2,bl2,al3,bl3,al4,bl4,al5,bl5,al6,bl6,a1,b1,a2,b2)
The results are here:
epoch 0 loss = 50530525184.0
tensor(-3.1266e+10, device=‘cuda:0’)
Test accuracy of the model on the 10000 test images:10.15%
epoch 1 loss = 811278598144.0
tensor(-4.7532e+11, device=‘cuda:0’)
epoch 2 loss = 86112829440.0
tensor(-5.2734e+10, device=‘cuda:0’)
epoch 3 loss = 158551998464.0
tensor(-1.0209e+11, device=‘cuda:0’)
epoch 4 loss = 69080948736.0
tensor(-4.3354e+10, device=‘cuda:0’)
epoch 5 loss = 244970094592.0
tensor(-1.4066e+11, device=‘cuda:0’)
Test accuracy of the model on the 10000 test images:11.129999999999999%
epoch 6 loss = 61311873024.0
tensor(-3.6518e+10, device=‘cuda:0’)
epoch 7 loss = 100857921536.0
tensor(-6.2719e+10, device=‘cuda:0’)
epoch 8 loss = 21799936000.0
tensor(-1.3609e+10, device=‘cuda:0’)
epoch 9 loss = 35716399104.0
tensor(-2.2920e+10, device=‘cuda:0’)
epoch 10 loss = 4038941440.0
tensor(-2.5166e+09, device=‘cuda:0’)
Test accuracy of the model on the 10000 test images:11.59%
and so on…