How did you implement it in the end?
Could you post your code and where you would fuse parameters to save operations?
Sorry for the late reply.

I have gone through this paper: https://arxiv.org/pdf/2002.11018v1.pdf (page 3, equation 9).
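In short, equation 9 folds the BatchNorm scale and shift into the adjacent layer's weights and bias, so the BN op disappears at inference time. Here is a minimal standalone check of the Conv2d + BatchNorm2d case (toy sizes of my own, just to illustrate the fold):

```
import torch
import torch.nn as nn

# Toy check that folding a trained BatchNorm2d into the preceding Conv2d
# leaves the eval-time output unchanged:
#   W_fused = W * gamma / sqrt(var + eps)
#   b_fused = (b - mean) / sqrt(var + eps) * gamma + beta
torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)   # pretend the BN has seen data
bn.running_var.uniform_(0.5, 2)
bn.weight.data.uniform_(0.5, 2)
bn.bias.data.uniform_(-1, 1)

std = torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, kernel_size=3, bias=True)
fused.weight = nn.Parameter(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
fused.bias = nn.Parameter((conv.bias - bn.running_mean) / std * bn.weight + bn.bias)

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```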
Have a look at the full code:

```
import os
import sys
import gc
import time
import shutil
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.model_zoo as model_zoo
from PIL import Image

import util

# set the seed
torch.manual_seed(1)
torch.cuda.manual_seed(1)
parser = argparse.ArgumentParser(description='AlexNet')
parser.add_argument('--arch', '-a', metavar='ARCH', default='alexnet',
                    help='model architecture (default: alexnet)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('-m', '--trained_model', default='weights/alexnet.baseline.pytorch.pth.tar',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-5)')
parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
                    help='path to imagenet data (default: ./data/)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
args = parser.parse_args()
__all__ = ['AlexNet', 'alexnet']
best_prec1 = 0
bin_op = None
c18 = None  # holds the classifier BatchNorm1d until it is fused into the final Linear
class DummyModule(nn.Module):
    """Identity module used to replace a fused-away BatchNorm layer."""
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x
def fuse(conv, bn):
    """Fold a BatchNorm2d into the Conv2d that precedes it (paper, eq. 9)."""
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    # bn(conv(x)) = gamma * (conv(x) - mean) / var_sqrt + beta
    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           groups=conv.groups,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv
def fuse_module(m):
    children = list(m.named_children())
    c = None
    cn = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            # replace the preceding Conv2d with the fused layer and the BN with an identity
            m._modules[cn] = fuse(c, child)
            m._modules[name] = DummyModule()
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        elif isinstance(child, BinConv2d):
            # BinConv2d binarizes between its BN and its conv, so that BN
            # cannot be folded into the weights; stop at the first one
            break
        else:
            fuse_module(child)
class DummyModule_1(nn.Module):
    """Identity module used to replace the fused-away classifier BatchNorm1d."""
    def __init__(self):
        super(DummyModule_1, self).__init__()

    def forward(self, x):
        return x
def fuse_1(linear, bn):
    """Fold a BatchNorm1d that comes BEFORE a Linear layer into that Linear."""
    w = linear.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight  # BN scale
    beta = bn.bias     # BN shift
    if linear.bias is not None:
        b = linear.bias
    else:
        b = w.new_zeros(linear.out_features)
    w = w.cuda()
    b = b.cuda()
    # linear(bn(x)) = (W * gamma / var_sqrt) x + W (beta - gamma * mean / var_sqrt) + b
    k = beta - gamma * mean / var_sqrt
    v = torch.matmul(w, k)        # bias correction uses the ORIGINAL weight
    w = (w * gamma) / var_sqrt    # scale the weight columns afterwards
    b = b + v
    fused_linear = nn.Linear(linear.in_features,
                             linear.out_features)
    fused_linear.weight = nn.Parameter(w)
    fused_linear.bias = nn.Parameter(b)
    return fused_linear
def fuse_module_1(m):
    # in `classifier`, index '2' is the BatchNorm1d and index '4' the final
    # Linear (the Dropout at index '3' in between is an identity at eval time)
    global c18
    for name, child in list(m.named_children()):
        if name == '2' and isinstance(child, nn.BatchNorm1d):
            c18 = child  # remember the BN until the matching Linear shows up
        elif name == '4' and isinstance(child, nn.Linear):
            m[4] = fuse_1(child, c18)
            m[2] = DummyModule_1()
        else:
            fuse_module_1(child)
class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations (sign function), with a
    straight-through estimator for the backward pass.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input
class BinConv2d(nn.Module):
    def __init__(self, input_channels, output_channels,
                 kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0,
                 Linear=False):
        super(BinConv2d, self).__init__()
        self.layer_type = 'BinConv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout
        if dropout != 0:
            self.dropout = nn.Dropout(dropout)
        self.Linear = Linear
        if not self.Linear:
            self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            self.conv = nn.Conv2d(input_channels, output_channels,
                                  kernel_size=kernel_size, stride=stride,
                                  padding=padding, groups=groups)
        else:
            self.bn = nn.BatchNorm1d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            self.linear = nn.Linear(input_channels, output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.bn(x)
        x = BinActive.apply(x)
        if self.dropout_ratio != 0:
            x = self.dropout(x)
        if not self.Linear:
            x = self.conv(x)
        else:
            x = self.linear(x)
        x = self.relu(x)
        return x
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(96, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(256, 384, kernel_size=3, stride=1, padding=1),
            BinConv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=1),
            BinConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            BinConv2d(256 * 6 * 6, 4096, Linear=True),
            BinConv2d(4096, 4096, dropout=0.5, Linear=True),
            nn.BatchNorm1d(4096, eps=1e-3, momentum=0.1, affine=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x
"""def printbn(self, input, output):
print('Inside ' + self.__class__.__name__ + ' forward')
mean = input[0].mean(dim=0)
var = input[0].var(dim=0)
print(mean)"""
def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    input_size = 227
    if pretrained:
        model_path = 'model_list/alexnet.pth.tar'
        pretrained_model = torch.load(model_path)
        model.load_state_dict(pretrained_model['state_dict'])
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.weight_decay)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            c = float(m.weight.data[0].nelement())
            m.weight.data = m.weight.data.normal_(0, 2.0/c)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data = m.weight.data.zero_().add(1.0)
            m.bias.data = m.bias.data.zero_()
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # fold the BatchNorm layers into the adjacent Conv2d / Linear layers
    fuse_module(model)
    fuse_module_1(model)
    # print("after final fusing", model)
    cudnn.benchmark = True
"""for key, value in model.named_parameters():
np.set_printoptions(threshold=50000000,formatter={'float_kind':'{:f}'.format})
str1=".txt"
str3=".shape"
str2='./'+'merged_params/'+'_'+key+str1
str4='./'+'merged_params/'+'_'+key+str3+str1
#print(str2)
#print(str4)
#array=np.asarray(value)
s=value.cpu().detach().numpy()
#print(s)
file2=open(str2,"w")
data=s[:]
file2.write(str(data))
#file2.write(s)
file2.close()
file3=open(str4,"w")
#data=s[:]
file3.write(str(value.shape))
#file2.write(s)
file3.close()"""
    print('==> Using Pytorch Dataset')
    import torchvision
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[1./255., 1./255., 1./255.])
    torchvision.set_image_backend('PIL')
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    img = Image.open("/content/XNOR-Net-PyTorch/ImageNet/networks/data/val/n01440764/ILSVRC2012_val_00002138.JPEG")
    a = np.array(img)
    print(np.max(a), np.min(a), a.dtype)
    print("a is", a.shape)

    global bin_op
    bin_op = util.BinOp(model)
    if args.evaluate:
        validate(val_loader, model, criterion)
        return
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    bin_op.binarization()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)
        values, indices = torch.max(output, 1)
        print("maximum value and its indices is", values, indices)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))
    bin_op.restore()
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    print('Learning rate:', lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
# checkpoint = torch.load(args.resume)
```
The fusing of the parameters occurs in `fuse`/`fuse_module` above (each Conv2d + BatchNorm2d pair in `features`) and in `fuse_1`/`fuse_module_1` (the BatchNorm1d that sits before the final Linear in `classifier`).
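Since that BatchNorm1d comes *before* the Linear, the fold is y = W·BN(x) + b rather than the usual conv-then-BN case, and the bias correction must use the original W. A minimal numeric check of that direction (toy sizes of my own, not from the script):

```
import torch
import torch.nn as nn

# Toy check of folding a BatchNorm1d that comes BEFORE a Linear layer:
#   W_fused = W * (gamma / std)
#   b_fused = b + W @ (beta - gamma * mean / std)   # ORIGINAL W here
torch.manual_seed(0)
bn = nn.BatchNorm1d(16).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2)
bn.weight.data.uniform_(0.5, 2)
bn.bias.data.uniform_(-1, 1)
linear = nn.Linear(16, 4).eval()

std = torch.sqrt(bn.running_var + bn.eps)
k = bn.bias - bn.weight * bn.running_mean / std
fused = nn.Linear(16, 4)
fused.bias = nn.Parameter(linear.bias + linear.weight @ k)
fused.weight = nn.Parameter(linear.weight * (bn.weight / std))

x = torch.randn(8, 16)
print(torch.allclose(linear(bn(x)), fused(x), atol=1e-5))  # True
```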
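To verify the whole thing end to end, one could compare the model output before and after fusing. A sketch, under the assumption that the model is built without pretrained weights and sits on the GPU (since `fuse_1` calls `.cuda()`):

```
# Hypothetical end-to-end check (not in the original script): in eval mode
# the fused model should produce numerically identical outputs.
model = AlexNet(num_classes=1000).cuda().eval()
x = torch.randn(2, 3, 227, 227).cuda()
with torch.no_grad():
    before = model(x)
    fuse_module(model)    # Conv2d + BatchNorm2d pairs in `features`
    fuse_module_1(model)  # BatchNorm1d + final Linear in `classifier`
    after = model(x)
print((before - after).abs().max())  # should be on the order of 1e-5
```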