Fusing of Layers

How did you implement it in the end?
Could you post your code and point out where you fuse the parameters to save operations?

Sorry for the late reply.
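In a nutshell, fusing just folds the (frozen) BatchNorm statistics into the neighbouring layer, so a single fused layer replaces the pair at inference time and the separate BatchNorm op disappears. A minimal sketch of the algebra behind `fuse()` and `fuse_1()` in the script below (assuming `conv`, `linear` and `bn` are the modules being fused and that `conv`/`linear` have a bias):

```
import torch

# Conv2d followed by BatchNorm2d  ->  one fused Conv2d
scale   = bn.weight / torch.sqrt(bn.running_var + bn.eps)
w_fused = conv.weight * scale.reshape(-1, 1, 1, 1)       # scale each output channel
b_fused = (conv.bias - bn.running_mean) * scale + bn.bias

# BatchNorm1d followed by Linear  ->  one fused Linear (the classifier case)
scale   = bn.weight / torch.sqrt(bn.running_var + bn.eps)
w_fused = linear.weight * scale                           # scale each input feature
b_fused = linear.bias + linear.weight @ (bn.bias - bn.running_mean * scale)
```

Full script: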

```
import os
import time
import shutil
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.utils.model_zoo as model_zoo
from PIL import Image

import util
# set the seed
torch.manual_seed(1)
torch.cuda.manual_seed(1)

import sys
import gc

parser = argparse.ArgumentParser(description='Alexnet')
parser.add_argument('--arch', '-a', metavar='ARCH', default='alexnet',
                    help='model architecture (default: alexnet)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('-m', '--trained_model', default='weights/alexnet.baseline.pytorch.pth.tar',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-5)')  
parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
                    help='path to imagenet data (default: ./data/)') 
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')  
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)') 
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')                                       

args = parser.parse_args()

__all__ = ['AlexNet', 'alexnet']

best_prec1 = 0
bin_op = None
c1 = 0
count = 0
count1 = 0

class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        # print("Dummy, Dummy.")
        return x


def fuse(conv, bn):
    """Fold a BatchNorm2d into the preceding Conv2d and return the fused Conv2d."""
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    beta = bn.weight   # BN scale
    gamma = bn.bias    # BN shift

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    # fold the normalization into the conv weights (per output channel) ...
    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    # ... and into the bias
    b = (b - mean) / var_sqrt * beta + gamma

    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           groups=conv.groups,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_module(m):
    """Recursively replace each Conv2d + BatchNorm2d pair with a fused Conv2d."""
    children = list(m.named_children())
    c = None
    cn = None

    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            m._modules[cn] = fuse(c, child)
            m._modules[name] = DummyModule()   # keep the Sequential indices stable
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        elif isinstance(child, BinConv2d):
            # BinConv2d handles its own BatchNorm, so only the first
            # full-precision Conv2d + BatchNorm2d pair gets fused
            break
        else:
            fuse_module(child)

class DummyModule_1(nn.Module):
    def __init__(self):
        super(DummyModule_1, self).__init__()

    def forward(self, x):
        # print("Dummy, Dummy.")
        return x


def fuse_1(linear, bn):
    """Fold a BatchNorm1d that runs *before* a Linear layer into that Linear layer."""
    w = linear.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight   # BN scale
    gamma = bn.bias    # BN shift

    if linear.bias is not None:
        b = linear.bias
    else:
        b = w.new_zeros(linear.out_features)

    w = w.cuda()
    b = b.cuda()

    # y = W * (beta * (x - mean) / sqrt(var + eps) + gamma) + b
    #   = (W * beta / sqrt(var + eps)) x + W @ (gamma - beta * mean / sqrt(var + eps)) + b
    k = gamma - beta * mean / var_sqrt
    v = torch.matmul(w, k)        # bias correction uses the *original* weights
    w = (w * beta) / var_sqrt     # then scale the weights per input feature
    b = b + v

    fused_linear = nn.Linear(linear.in_features,
                             linear.out_features)
    fused_linear.weight = nn.Parameter(w)
    fused_linear.bias = nn.Parameter(b)
    return fused_linear


def fuse_module_1(m):
    """Fold the classifier's BatchNorm1d (index '2') into the final Linear (index '4')."""
    global c18

    for name, child in m.named_children():
        if name == '2' and isinstance(child, nn.BatchNorm1d):
            # remember the BatchNorm1d; in this classifier it precedes the Linear
            c18 = child
        elif name == '4' and isinstance(child, nn.Linear):
            m[4] = fuse_1(child, c18)   # fused Linear replaces the original
            m[2] = DummyModule_1()      # the BatchNorm1d becomes a no-op
        else:
            fuse_module_1(child)

class BinActive(torch.autograd.Function):
    '''
    Binarize the input activations with sign(); the backward pass clips
    gradients outside [-1, 1] (straight-through estimator).
    '''
    def forward(self, input):
        self.save_for_backward(input)
        return input.sign()

    def backward(self, grad_output):
        input, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        return grad_input

class BinConv2d(nn.Module): # change the name of BinConv2d
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, dropout=0,
            Linear=False):
        super(BinConv2d, self).__init__()
        self.layer_type = 'BinConv2d'
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dropout_ratio = dropout

        if dropout!=0:
            self.dropout = nn.Dropout(dropout)
        self.Linear = Linear
        if not self.Linear:
            self.bn = nn.BatchNorm2d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            self.conv = nn.Conv2d(input_channels, output_channels,
                    kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
        else:
            self.bn = nn.BatchNorm1d(input_channels, eps=1e-4, momentum=0.1, affine=True)
            self.linear = nn.Linear(input_channels, output_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.bn(x)
        x = BinActive()(x)
        if self.dropout_ratio != 0:
            x = self.dropout(x)
        if not self.Linear:
            x = self.conv(x)
        else:
            x = self.linear(x)
        x = self.relu(x)
        return x

class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.BatchNorm2d(96, eps=1e-4, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BinConv2d(256, 384, kernel_size=3, stride=1, padding=1),
            BinConv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=1),
            BinConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            BinConv2d(256 * 6 * 6, 4096, Linear=True),
            BinConv2d(4096, 4096, dropout=0.5, Linear=True),
            nn.BatchNorm1d(4096, eps=1e-3, momentum=0.1, affine=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x

"""def printbn(self, input, output):
    print('Inside ' + self.__class__.__name__ + ' forward')
    mean = input[0].mean(dim=0)
    var = input[0].var(dim=0)
    print(mean)"""



def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(**kwargs)
    input_size = 227


    if pretrained:
        model_path = 'model_list/alexnet.pth.tar'
        pretrained_model = torch.load(model_path)
        model.load_state_dict(pretrained_model['state_dict'])
    #return model

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                weight_decay=args.weight_decay)

    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            c = float(m.weight.data[0].nelement())
            m.weight.data = m.weight.data.normal_(0, 2.0/c)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data = m.weight.data.zero_().add(1.0)
            m.bias.data = m.bias.data.zero_()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    
    fuse_module(model)
    fuse_module_1(model)
    #print("after final fusing",model)
    
    cudnn.benchmark = True

    """for key, value in model.named_parameters():
      np.set_printoptions(threshold=50000000,formatter={'float_kind':'{:f}'.format})
      str1=".txt"
      str3=".shape"
      str2='./'+'merged_params/'+'_'+key+str1
      str4='./'+'merged_params/'+'_'+key+str3+str1
      #print(str2)
      #print(str4)
          #array=np.asarray(value)
      s=value.cpu().detach().numpy()
        #print(s)
      file2=open(str2,"w")
      data=s[:]
      file2.write(str(data))
        #file2.write(s)
      file2.close()

      file3=open(str4,"w")
        #data=s[:]
      file3.write(str(value.shape))
        #file2.write(s)
      file3.close()"""
        
      

    print('==> Using Pytorch Dataset')
    import torchvision
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                std=[1./255., 1./255., 1./255.])

    torchvision.set_image_backend('PIL')

    val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.CenterCrop(input_size),
                    transforms.ToTensor(),
                    normalize,
                    ])),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)

    img = Image.open("/content/XNOR-Net-PyTorch/ImageNet/networks/data/val/n01440764/ILSVRC2012_val_00002138.JPEG")
    a = np.array(img)
    print(np.max(a),np.min(a),a.dtype)
    print(" a is",a.shape)

    global bin_op
    bin_op = util.BinOp(model)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return    

def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    
    model.eval()
    #print("after fusing model",model)
    
    end = time.time()
    bin_op.binarization()
    
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        with torch.no_grad():
            input_var = torch.autograd.Variable(input)
            target_var = torch.autograd.Variable(target)
        #print(input)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)
        values,indices=torch.max(output,1)
        #print("output is",output)
        print("maximum value and its indices is",values,indices)
        #print("target is",target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))
    bin_op.restore()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    print('Learning rate:', lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
 

```
The fusing itself occurs in `fuse_1`/`fuse_module_1` above (and in `fuse`/`fuse_module` for the Conv2d + BatchNorm2d pair); both are called in `alexnet()` right after the checkpoint is restored.
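If it helps, a quick sanity check is to compare the outputs of the model before and after fusing on one random batch in eval mode. A minimal sketch, assuming `model` has already been built and its checkpoint loaded, but `fuse_module`/`fuse_module_1` have not been applied to it yet:

```
import copy
import torch

def check_fusion(model, input_size=227):
    """Compare outputs before and after BN fusion on a single random batch."""
    model.eval()
    fused = copy.deepcopy(model)
    fuse_module(fused)      # Conv2d + BatchNorm2d
    fuse_module_1(fused)    # BatchNorm1d + Linear in the classifier
    fused.eval()

    x = torch.randn(4, 3, input_size, input_size)
    if next(model.parameters()).is_cuda:
        x = x.cuda()
    with torch.no_grad():
        diff = (model(x) - fused(x)).abs().max()
    # only small floating-point differences from the re-ordered arithmetic are expected
    print('max abs difference:', diff.item())
```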
