@rasbt, thank you for your reply.
I have enclosed my entire code below. The parameters of my network are initialized using the trained weights of a full-precision neural network with the same architecture.
#load all the necessary libraries
import torch
import torch.nn
import numpy as np
torch.backends.cudnn.deterministic=True
#from utils import plot_images
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
import torch.nn.functional as F
from torch.distributions import Categorical
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sigm=torch.nn.Sigmoid()
#import the CIFAR-10 datasets
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )
    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    # load the dataset
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        data_iter = iter(sample_loader)
        images, labels = next(data_iter)
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)
    return (train_loader, valid_loader)
train_loader,valid_loader=get_train_valid_loader(data_dir='C://Users//AEON-LAB PC//.spyder-py3//CIFAR_10',
                                                 batch_size=128,
                                                 augment=True,
                                                 random_seed=999,
                                                 valid_size=0.2,
                                                 shuffle=True,
                                                 show_sample=False,
                                                 num_workers=1,
                                                 pin_memory=True)
#Class that performs the convolution operation according to the reparameterization trick (Shayer et al.);
#a small numeric sanity check of the mean/variance formulas follows the class
class Repcnn(torch.nn.Module):
    def __init__(self,wfp):
        super(Repcnn,self).__init__()
        self.a,self.b=self.initialize(wfp)
        #print(self.a.norm())

    def initialize(self,wfp):
        wtilde=wfp/torch.std(wfp)
        sigma_a=0.95-((0.95-0.05)*torch.abs(wtilde))
        sigma_b=0.5*(1+(wfp/(1-sigma_a)))
        sigma_a=torch.clamp(sigma_a,0.05,0.95)
        sigma_b=torch.clamp(sigma_b,0.05,0.95)
        a=torch.log(sigma_a/(1-sigma_a)).requires_grad_().cuda()
        b=torch.log(sigma_b/(1-sigma_b)).requires_grad_().cuda()
        return torch.nn.Parameter(a),torch.nn.Parameter(b)

    def forward(self,x):
        weight_m=(2*sigm(self.b)-(2*sigm(self.a)*sigm(self.b))-1+sigm(self.a))
        #print(self.a.norm())
        weight_v=(1-sigm(self.a))-weight_m**2
        assert torch.all(weight_v>=0)
        om=F.conv2d(x,weight_m,padding=1)
        ov=F.conv2d(x**2,weight_v,padding=1)
        assert torch.all(ov>=0)
        e=torch.randn_like(ov).cuda()
        z=om+(ov*e)
        return z
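As a side check (not part of the network), here is a small numeric sketch verifying that the weight_m/weight_v formulas above match the mean and variance of actual ternary weights, assuming the parameterization from Shayer et al. where w takes values in {-1,0,1} with P(w=0)=sigmoid(a) and P(w=1|w!=0)=sigmoid(b); a_test and b_test are arbitrary example values:
#numeric sanity check (sketch): compare the analytic mean/variance with Monte Carlo estimates
a_test=torch.tensor(0.3)
b_test=torch.tensor(-0.7)
sa,sb=torch.sigmoid(a_test),torch.sigmoid(b_test)
mean_analytic=(1-sa)*(2*sb-1)            #same expression as weight_m
var_analytic=(1-sa)-mean_analytic**2     #same expression as weight_v
n=1000000
w=torch.zeros(n)
nonzero=torch.rand(n)>=sa                #w is nonzero with probability 1-sigmoid(a)
sign=torch.where(torch.rand(n)<sb,torch.ones(n),-torch.ones(n))
w[nonzero]=sign[nonzero]
print(mean_analytic.item(),w.mean().item())   #should be close
print(var_analytic.item(),w.var().item())     #should be close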
#class that performs the linear operation in the fully connected layers using the reparameterization trick
class Repfc(torch.nn.Module):
    def __init__(self,wfp):
        super(Repfc,self).__init__()
        self.a1,self.b1=self.initialize(wfp)

    def initialize(self,wfp):
        wtilde=wfp/torch.std(wfp)
        sigma_a=0.95-((0.95-0.05)*torch.abs(wtilde))
        sigma_b=0.5*(1+(wfp/(1-sigma_a)))
        sigma_a=torch.clamp(sigma_a,0.05,0.95)
        sigma_b=torch.clamp(sigma_b,0.05,0.95)
        a=torch.log(sigma_a/(1-sigma_a))
        b=torch.log(sigma_b/(1-sigma_b))
        return torch.nn.Parameter(a),torch.nn.Parameter(b)

    def forward(self,x):
        weight_m=(2*sigm(self.b1)-(2*sigm(self.a1)*sigm(self.b1))-1+sigm(self.a1))
        weight_v=(1-sigm(self.a1))-weight_m**2
        om=torch.matmul(weight_m,x)
        ov=torch.matmul(weight_v,x**2)
        e=torch.randn_like(ov).cuda()
        z=om+(ov*e)
        return z
#weight initialization using the trained full-precision network
model=torch.load('/content/vgg_8_without_relu.pth',map_location='cpu')
wfp=[]
wfp.append(model['features.0.weight'])
wfp.append(model['features.3.weight'])
wfp.append(model['features.7.weight'])
wfp.append(model['features.10.weight'])
wfp.append(model['features.14.weight'])
wfp.append(model['features.17.weight'])
wfp.append(model['classifier.1.weight'])
wfp.append(model['classifier.2.weight'])
for i in range(len(wfp)):
    wfp[i]=torch.Tensor(wfp[i])
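Before building the network I also print the weight tensors in the checkpoint, just to confirm that the keys and shapes line up with the order in which wfp is built above (a quick sketch):
#quick check (sketch): list the checkpoint's weight tensors and their shapes
for k,v in model.items():
    if k.endswith('.weight'):
        print(k,tuple(v.shape))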
#Forward propagation and training
class Conv_Net(torch.nn.Module):
    def __init__(self,wfp):
        super(Conv_Net,self).__init__()
        self.hidden=torch.nn.ModuleList([])
        self.batchnorm=torch.nn.ModuleList([])
        for i in range(6):
            cnn=Repcnn(wfp[i])
            self.hidden.append(cnn)
        for j in range(2):
            fc=Repfc(wfp[i+1])
            i+=1
            self.hidden.append(fc)
        batch_dim=[128,256,512]
        for i in batch_dim:
            self.batchnorm.append(torch.nn.BatchNorm2d(i))
        self.mp=torch.nn.MaxPool2d(kernel_size=2,stride=2)
        self.drop=torch.nn.Dropout()
        self.activation=torch.nn.ReLU()

    def forward(self,x):
        op=x
        j=0
        while(j<6):
            obj=self.hidden[j]
            obj_next=self.hidden[j+1]
            b=self.batchnorm[j//2]
            j+=2
            op=self.mp(self.activation(b(obj_next(self.activation(b(obj(op)))))))
        op=op.view(op.size(0),-1)
        op=torch.t(op)
        obj=self.hidden[j]
        op=obj(self.drop(op))
        j+=1
        obj=self.hidden[j]
        yout=obj(op)
        yout=torch.t(yout)
        #print(yout)
        return yout
net=Conv_Net(wfp).to(device)
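A quick shape check (sketch, assuming 32x32 CIFAR-10 inputs) to confirm that the forward pass returns one logit vector of length 10 per image:
#shape check (sketch): a dummy CIFAR-10 sized batch should give logits of shape (batch, 10)
with torch.no_grad():
    dummy=torch.randn(2,3,32,32).to(device)
    print(net(dummy).shape)   #expected: torch.Size([2, 10])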
def l2_reg():
    sum=0
    for p in net.parameters():
        sum+=p.norm(2)
    return sum
l_rate=0.01
#lr_decay=20
beta_param=1e-11
weight_decay=1e-11
optimizer=torch.optim.Adam(net.parameters(),lr=l_rate,weight_decay=weight_decay)
criterion=torch.nn.CrossEntropyLoss().cuda()
net.train()
num_epochs=300
for epoch in range(num_epochs):
    if(epoch==170):
        lr=0.001
        for param_group in optimizer.param_groups:
            param_group['lr']=lr
    for i,(images,labels) in enumerate(train_loader):
        images=images.to(device)
        labels=labels.to(device)
        #print(i)
        #torch.cuda.empty_cache()
        optimizer.zero_grad()
        yout=net(images)
        loss_batch=criterion(yout,labels)+(beta_param*l2_reg())
        loss_batch.backward()
        optimizer.step()
    print('epoch {}'.format(epoch),'loss {}'.format(loss_batch.item()))
    sum_grad=0
    for p in net.parameters():
        sum_grad+=p.grad.norm()
    print('sum of the gradients of all parameters in epoch {} is {}'.format(epoch,sum_grad))
#evaluation
net.eval()
with torch.no_grad():
    correct=0
    total=0
    for images,labels in valid_loader:
        images=images.to(device)
        labels=labels.to(device)
        yout=net(images)
        _,predicted=torch.max(yout,1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
    print('Accuracy of the model on the 10000 validation images:{}%'.format((correct/total)*100))
The problem is that the loss decreases until it reaches about 2.30258 and then stays constant. 2.30258 is ln(10), i.e., the cross-entropy of a uniform prediction over the 10 CIFAR-10 classes.
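To see whether the network has really collapsed to uniform predictions, and whether the gradients vanish somewhere, this is roughly the kind of check I can run on a single batch (a sketch, not part of the training script):
#diagnostic sketch: softmax rows of ~0.1 everywhere would explain a constant loss of ln(10)
images,labels=next(iter(valid_loader))
images,labels=images.to(device),labels.to(device)
net.train()
optimizer.zero_grad()
yout=net(images)
print(F.softmax(yout,dim=1)[:5])
loss=criterion(yout,labels)
loss.backward()
for name,p in net.named_parameters():
    print(name,p.grad.norm().item())   #per-parameter gradient norms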