I wanna do channel pruning using PyTorch. (usnig 0.4.0)
When I run the retraining mode, I got this error.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
Here is my code.
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.autograd import Variable
from utils import progress_bar
import os
import argparse
import vgg16
import copy
parser = argparse.ArgumentParser(description='PyTorch CIFAR10')
parser.add_argument('--mode', default=1, type=int, help='load models or make models') # mode=0 make, mode=1 load
parser.add_argument('--cal', default=1, type=int, help='calculate parameters and MAC ops') # cal=0 no calculation, cal=1 calculation
parser.add_argument('--nl', default=16, type=int, help='number of CONV and FC layers') # For VGG16, nl=16
parser.add_argument('--mc', default=512, type=int, help='maximum channel depth') # For VGG16, mc=512
parser.add_argument('--chpr', default=1, type=int, help='pruning or not') # pr=0 no pruning, pr=1 pruning
parser.add_argument('--convpr', default=1, type=int, help='conv pruning')
parser.add_argument('--fcpr', default=1, type=int, help='fc pruning')
parser.add_argument('--infer', default=1, type=int, help='without retraining or with retraining')
parser.add_argument('--pr', default=0.29, type=float, help='pruning ratio')
parser.add_argument('--lr', default=0.004, type=float, help='learning rate')
parser.add_argument('--bs', default=512, type=int, help='batch size')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
use_cuda = torch.cuda.is_available()
transform_train = transforms.Compose([transforms.RandomCrop(32,padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
transform_test = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
cifar_train = dset.CIFAR10("./", train=True, transform=transform_train, target_transform=None, download=True)
cifar_test = dset.CIFAR10("./", train=False, transform=transform_test, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=args.bs, shuffle=True, num_workers=2, drop_last=False)
test_loader = torch.utils.data.DataLoader(cifar_test, batch_size=10000, shuffle=False, num_workers=2, drop_last=False)
# Model
if args.mode == 1:
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt_20190312_cifar10_vgg16.t7') # Load your own neural network models
net = checkpoint['net']
best_acc = checkpoint['acc']
else:
print('==> Making model')
net = vgg16.VGG()
mask = vgg16.VGG()
if use_cuda:
net.cuda()
mask.cuda()
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
mask = torch.nn.DataParallel(mask, device_ids=range(torch.cuda.device_count()))
cudnn.benchmark = True
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Retraining
def retrain(epoch, mask):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(train_loader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
optimizer.zero_grad()
inputs, targets = Variable(inputs), Variable(targets)
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
masking(net, mask)
optimizer.step()
train_loss += loss.data
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()
progress_bar(batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total))
def test(mask):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(test_loader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
inputs, targets = Variable(inputs), Variable(targets)
masking(net, mask)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.data
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()
progress_bar(batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total))
acc = 100.*float(correct)/float(total)
return acc
def channel_pruning_mask(net, mask):
total_num_params_pruned = 0
total_num_mac_ops_pruned = 0
filter_sum = torch.zeros((args.nl, args.mc))
idx = 0
fake_input = torch.ones([1,3,32,32]).cuda()
if args.convpr == 1:
for i in range(0, len(net.module.features)):
fake_input_size = fake_input.size()
if net.module.features[i].__class__.__name__ == 'Conv2d':
weight_size = net.module.features[i].weight.size()
bias_size = net.module.features[i].bias.size()
for j in range(0, weight_size[0]):
filter_sum[idx][j] = torch.sum(torch.abs(net.module.features[i].weight[j])) # i = layer index, j = channel index
sorted_values_conv, sorted_indices_conv = torch.topk(filter_sum[idx][0:weight_size[0]], round((1 - args.pr) * weight_size[0]))
for j in range(0, weight_size[0]):
if j in sorted_indices_conv:
params_mask = list(mask.module.features[i].parameters())
params_mask[0][j] = torch.ones(weight_size[1], weight_size[2], weight_size[3], requires_grad=True)
else:
params_mask = list(mask.module.features[i].parameters())
params_mask[0][j] = torch.zeros(weight_size[1], weight_size[2], weight_size[3], requires_grad=True)
idx += 1
fake_output = net.module.features[i](fake_input)
fake_input = fake_output
fake_input = fake_input.view(fake_input.size(0), -1)
else:
pass
if args.fcpr == 1:
for i in range(0, len(net.module.classifier)):
if net.module.classifier[i].__class__.__name__ == 'Linear':
weight_size = net.module.classifier[i].weight.size()
bias_size = net.module.classifier[i].bias.size()
for j in range(0, weight_size[1]):
filter_sum[idx][j] = torch.sum(torch.abs(net.module.classifier[i].weight[:,j]))
sorted_values_fc, sorted_indices_fc = torch.topk(filter_sum[idx][0:weight_size[1]], round((1 - args.pr) * weight_size[1]))
for j in range(0, weight_size[1]):
if j in sorted_indices_fc:
params_mask = list(mask.module.classifier[i].parameters())
params_mask[0][:,j] = torch.ones(weight_size[0], requires_grad=True)
else:
params_mask = list(mask.module.classifier[i].parameters())
params_mask[0][:,j] = torch.zeros(weight_size[0], requires_grad=True)
idx += 1
fake_output = net.module.classifier[i](fake_input)
fake_input = fake_output
else:
pass
return net, mask
def masking(net, mask):
for i in range(0, len(net.module.features)):
if net.module.features[i].__class__.__name__ == 'Conv2d':
weight_size = net.module.features[i].weight.size()
for j in range(0, weight_size[0]):
params_mask = list(mask.module.features[i].parameters())
params_net = list(net.module.features[i].parameters())
params_net[0][j] = torch.mul(params_net[0][j], params_mask[0][j])
"""
if args.infer == 0:
params_net[0][j].grad = torch.mul(params_net[0][j].grad, params_mask[0][j].grad)
else:
pass
"""
for i in range(0, len(net.module.classifier)):
if net.module.classifier[i].__class__.__name__ == 'Linear':
weight_size = net.module.classifier[i].weight.size()
for j in range(0, weight_size[1]):
params_mask = list(mask.module.classifier[i].parameters())
params_net = list(net.module.classifier[i].parameters())
params_net[0][:,j] = torch.mul(params_net[0][:,j], params_mask[0][:,j])
"""
if args.infer == 0:
params_net[0][:,j].grad = torch.mul(params_net[0][:,j].grad, params_mask[0][:,j].grad)
else:
pass
"""
# Channel pruning like He_ICCV_2017 (I think criterion is much simpler than He_ICCV_2017)
if args.chpr == 1:
channel_pruning_mask(net, mask)
else:
pass
# Without retraining or with retraining
if args.infer == 1:
print("<< Output channel pruning without retraining >> \n")
test(mask)
else:
print("<< Output channel pruning with retraining >> \n")
for epoch in range(0,1):
retrain(epoch, mask)
test(mask)
Which operations are inplace operations??