import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import argparse
import shutil
import time
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import numpy as np
import util
# set the seed
torch.manual_seed(1)
torch.cuda.manual_seed(1)
import sys
import gc
parser = argparse.ArgumentParser(description='Alexnet')
parser.add_argument('--arch', '-a', metavar='ARCH', default='alexnet',
help='model architecture (default: alexnet)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('-m', '--trained_model', default='weights/alexnet.baseline.pytorch.pth.tar',
type=str, help='Trained state_dict file path to open')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
help='path to imagenet data (default: ./data/)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
help='number of data loading workers (default: 8)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
args = parser.parse_args()
__all__ = ['AlexNet', 'alexnet']
best_prec1 = 0
bin_op = None
count = 0   # number of binarized Linear submodules seen by fuse_module_1
c18 = None  # the second binarized Linear; its trailing BatchNorm1d is fused into it
class DummyModule(nn.Module):
def __init__(self):
super(DummyModule, self).__init__()
    def forward(self, x):
        # identity pass-through that replaces a fused BatchNorm2d
        return x
def fuse(conv, bn):
    """Fold a BatchNorm2d into the preceding Conv2d (valid in eval mode)."""
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight   # BN scale
    gamma = bn.bias    # BN shift
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    # scale each output filter by beta/sqrt(var+eps) and shift the bias so
    # that fused_conv(x) == bn(conv(x))
    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * beta + gamma
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           groups=conv.groups,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv
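# A minimal sanity check for fuse(), not part of the original script and not
# called anywhere: it compares the fused convolution against the original
# Conv2d -> BatchNorm2d pair in eval mode. The helper name is ours.
def _check_conv_bn_fusion():
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    bn.eval()  # fusion assumes frozen running statistics
    x = torch.randn(2, 3, 16, 16)
    ref = bn(conv(x))
    fused = fuse(conv, bn)
    assert torch.allclose(fused(x), ref, atol=1e-5)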
def fuse_module(m):
    """Recursively replace Conv2d -> BatchNorm2d pairs with fused Conv2d layers."""
    children = list(m.named_children())
    c = None
    cn = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            # fuse this BN into the Conv2d recorded just before it
            bc = fuse(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        elif isinstance(child, BinConv2d):
            # binarized blocks need their BatchNorm before sign(); stop here
            break
        else:
            fuse_module(child)
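# A usage sketch for fuse_module() (ours, not called by the script): after
# fusion, the BatchNorm slot holds a DummyModule and the output is unchanged.
def _demo_fuse_module():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net.eval()  # fuse only a model with frozen BN statistics
    x = torch.randn(1, 3, 8, 8)
    ref = net(x)
    fuse_module(net)
    assert isinstance(net[1], DummyModule)
    assert torch.allclose(net(x), ref, atol=1e-5)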
class DummyModule_1(nn.Module):
    def __init__(self):
        super(DummyModule_1, self).__init__()
    def forward(self, x):
        # identity pass-through that replaces a fused BatchNorm1d
        return x
def fuse_1(linear, bn):
    """Fold a BatchNorm1d into the preceding Linear (valid in eval mode)."""
    w = linear.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight   # BN scale
    gamma = bn.bias    # BN shift
    if linear.bias is not None:
        b = linear.bias
    else:
        b = mean.new_zeros(mean.shape)
    w = w.cuda()
    b = b.cuda()
    # scale per output feature instead of hard-coding 4096
    w = w * (beta / var_sqrt).reshape([linear.out_features, 1])
    b = (b - mean) / var_sqrt * beta + gamma
    # debugging exports of the fused parameters, disabled by default:
    #np.savetxt("weight.txt", w.cpu().detach().numpy(), fmt='%.6f')
    #np.savetxt("bias.txt", b.cpu().detach().numpy(), fmt='%.6f')
    fused_linear = nn.Linear(linear.in_features,
                             linear.out_features)
    fused_linear.weight = nn.Parameter(w)
    fused_linear.bias = nn.Parameter(b)
    return fused_linear
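# A minimal sanity check for fuse_1() (ours, not called by the script).
# Because fuse_1() moves tensors to the GPU, it assumes a CUDA device.
def _check_linear_bn_fusion():
    linear = nn.Linear(16, 8).cuda()
    bn = nn.BatchNorm1d(8).cuda()
    bn.eval()  # fusion assumes frozen running statistics
    x = torch.randn(4, 16).cuda()
    ref = bn(linear(x))
    fused = fuse_1(linear, bn)
    assert torch.allclose(fused(x), ref, atol=1e-5)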
def fuse_module_1(m):
    """Fuse the classifier's standalone BatchNorm1d into the preceding
    binarized Linear layer, replacing the BN with an identity module."""
    children = list(m.named_children())
    global count
    global c18
    for name, child in children:
        if name == 'linear':
            count = count + 1
            if count == 2:
                # the second 'linear' submodule belongs to classifier[1];
                # the standalone BatchNorm1d that follows is fused into it
                c18 = child
        elif name == '2' and isinstance(child, nn.BatchNorm1d):
            # here m is the classifier Sequential: fold the BatchNorm1d at
            # index 2 into the binarized Linear captured above
            bc = fuse_1(c18, child)
            m[1].linear = bc
            m[2] = DummyModule_1()
        else:
            fuse_module_1(child)
class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations with sign(); the backward pass is a
    straight-through estimator that zeroes gradients where |input| >= 1.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
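# A small illustration of the straight-through estimator (ours, not part of
# the original script): forward yields +/-1, backward masks gradients
# where |input| >= 1.
def _demo_binactive():
    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = BinActive.apply(x)        # tensor([-1., -1.,  1.,  1.])
    y.sum().backward()
    assert x.grad.tolist() == [0.0, 1.0, 1.0, 0.0]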
class BinConv2d(nn.Module): # change the name of BinConv2d
def __init__(self, input_channels, output_channels,
kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0,
Linear=False):
super(BinConv2d, self).__init__()
self.layer_type = 'BinConv2d'
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dropout_ratio = dropout
if dropout!=0:
self.dropout = nn.Dropout(dropout)
self.Linear = Linear
if not self.Linear:
self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
self.conv = nn.Conv2d(input_channels, output_channels,
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
else:
self.bn = nn.BatchNorm1d(input_channels, eps=1e-4, momentum=0.1, affine=True)
self.linear = nn.Linear(input_channels, output_channels)
self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.bn(x)
        # binarize activations with the straight-through autograd Function
        x = BinActive.apply(x)
        if self.dropout_ratio != 0:
            x = self.dropout(x)
        if not self.Linear:
            x = self.conv(x)
        else:
            x = self.linear(x)
        x = self.relu(x)
        return x
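# A minimal shape check for BinConv2d (ours, not called by the script): the
# block computes BN -> binarize -> conv -> ReLU, so it is shape-compatible
# with a plain Conv2d of the same geometry.
def _demo_binconv2d():
    block = BinConv2d(96, 256, kernel_size=5, stride=1, padding=2)
    x = torch.randn(2, 96, 27, 27)
    assert block(x).shape == (2, 256, 27, 27)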
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
nn.BatchNorm2d(96, eps=1e-4, momentum=0.1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
BinConv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=1),
nn.MaxPool2d(kernel_size=3, stride=2),
BinConv2d(256, 384, kernel_size=3, stride=1, padding=1),
BinConv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=1),
BinConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
BinConv2d(256 * 6 * 6, 4096, Linear=True),
BinConv2d(4096, 4096, dropout=0.5, Linear=True),
nn.BatchNorm1d(4096, eps=1e-3, momentum=0.1, affine=True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x
"""def printbn(self, input, output):
print('Inside ' + self.__class__.__name__ + ' forward')
mean = input[0].mean(dim=0)
var = input[0].var(dim=0)
print(mean)"""
def alexnet(pretrained=False, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = AlexNet(**kwargs)
input_size = 227
if pretrained:
model_path = 'model_list/alexnet.pth.tar'
pretrained_model = torch.load(model_path)
model.load_state_dict(pretrained_model['state_dict'])
#return model
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
model.features = torch.nn.DataParallel(model.features)
model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), args.lr,
weight_decay=args.weight_decay)
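    # initialize conv/linear weights with fan-in-scaled Gaussians and reset
    # BatchNorm to identity (weight 1, bias 0); note this runs after the
    # optional pretrained load above, while a --resume checkpoint below
    # will overwrite it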
for m in model.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
c = float(m.weight.data[0].nelement())
m.weight.data = m.weight.data.normal_(0, 2.0/c)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data = m.weight.data.zero_().add(1.0)
m.bias.data = m.bias.data.zero_()
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
del checkpoint
else:
print("=> no checkpoint found at '{}'".format(args.resume))
#print("before fusing model",model)
#print(model)
fuse_module(model)
fuse_module_1(model)
#print(model._modules[0].weight.size())
cudnn.benchmark = True
"""for key, value in model.named_parameters():
np.set_printoptions(threshold=50000000,formatter={'float_kind':'{:f}'.format})
str1=".txt"
str3=".shape"
str2='./'+'merged_params/'+'_'+key+str1
str4='./'+'merged_params/'+'_'+key+str3+str1
print(str2)
print(str4)
#array=np.asarray(value)
s=value.cpu().detach().numpy()
#print(s)
file2=open(str2,"w")
data=s[:]
file2.write(str(data))
#file2.write(s)
file2.close()
file3=open(str4,"w")
#data=s[:]
file3.write(str(value.shape))
#file2.write(s)
file3.close()"""
print('==> Using Pytorch Dataset')
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[1./255., 1./255., 1./255.])
torchvision.set_image_backend('PIL')
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
global bin_op
bin_op = util.BinOp(model)
if args.evaluate:
validate(val_loader, model, criterion)
return
def validate(val_loader, model, criterion):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
# switch to evaluate mode
model.eval()
#print("after fusing model",model)
end = time.time()
bin_op.binarization()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        # torch.autograd.Variable is deprecated; plain tensors work, and
        # no_grad() disables gradient tracking for the whole forward pass
        with torch.no_grad():
            input_var = input
            target_var = target
            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)
        # per-batch argmax inspection (debugging aid, disabled by default):
        #values, indices = torch.max(output, 1)
        #print("maximum value and its indices is", values, indices)
# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
bin_op.restore()
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
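# A short usage sketch for AverageMeter (ours): update() takes a per-batch
# value and the batch size; avg is the running sample-weighted mean.
def _demo_average_meter():
    meter = AverageMeter()
    meter.update(0.5, n=4)  # batch of 4 with mean value 0.5
    meter.update(1.0, n=2)  # batch of 2 with mean value 1.0
    assert meter.avg == (0.5 * 4 + 1.0 * 2) / 6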
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = args.lr * (0.1 ** (epoch // 30))
    print('Learning rate:', lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
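# A tiny worked example of accuracy() (ours, not part of the script): with
# both samples' top logit on class 1 and targets [1, 0], precision@1 is 50%
# and precision@2 is 100%.
def _demo_accuracy():
    output = torch.tensor([[0.1, 0.9], [0.2, 0.8]])
    target = torch.tensor([1, 0])
    top1, top2 = accuracy(output, target, topk=(1, 2))
    assert top1.item() == 50.0 and top2.item() == 100.0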