# NTK-based dataset distillation (KIP-style) on CIFAR-10.
import os
import copy
import torch
import random
import argparse
import warnings
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.utils import save_image
from torchvision import datasets, transforms
from functorch import make_functional, vmap, jacrev
warnings.filterwarnings("ignore")
CENTER = True
DATA_PATH = 'data/'
class AlexNet(nn.Module):
    """AlexNet-style CNN for small (32x32) images.

    Args:
        channel: number of input channels. Defaults to 3 (CIFAR-10),
            matching the values returned by ``get_dataset``.
        num_classes: number of output classes. Defaults to 10.

    The defaults also fix the no-argument construction ``AlexNet()`` used in
    ``main()``, which previously raised ``TypeError``; positional callers
    passing ``(channel, num_classes)`` are unaffected.
    """

    def __init__(self, channel=3, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.LocalResponseNorm(4, alpha=0.001 / 9.0, beta=0.75, k=1),
            nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(4, alpha=0.001 / 9.0, beta=0.75, k=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(4096, 384),
            nn.ReLU(inplace=True),
            nn.Linear(384, 192),
            nn.ReLU(inplace=True),
            nn.Linear(192, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for 32x32 input."""
        x = self.features(x)
        # Two stride-2 pools reduce 32x32 -> 8x8 spatially; 64 * 8 * 8 = 4096.
        x = x.view(x.size(0), 4096)
        x = self.classifier(x)
        return x
def empirical_ntk(fnet_single, params, x1, x2, compute='full'):
    """Empirical neural-tangent-kernel between two input batches.

    Args:
        fnet_single: functional model mapping (params, single_input) -> output.
        params: parameters the Jacobians are taken with respect to.
        x1, x2: batches of inputs (batched along dim 0).
        compute: 'full'     -> (N, M, out, out) kernel blocks,
                 'trace'    -> (N, M) traced kernel,
                 'diagonal' -> (N, M, out) per-output diagonal.

    Returns:
        The NTK, i.e. J(x1) @ J(x2)^T contracted per ``compute``, summed
        over all parameter tensors.
    """
    mode_to_expr = {
        'full': 'Naf,Mbf->NMab',
        'trace': 'Naf,Maf->NM',
        'diagonal': 'Naf,Maf->NMa',
    }
    assert compute in mode_to_expr
    expr = mode_to_expr[compute]

    def per_sample_jacobians(batch):
        # One Jacobian per parameter tensor, flattened to (N, out, n_params).
        jacs = vmap(jacrev(fnet_single), (None, 0))(params, batch)
        return [j.flatten(2) for j in jacs]

    jacs1 = per_sample_jacobians(x1)
    jacs2 = per_sample_jacobians(x2)

    # Contract each parameter-block pair and accumulate across parameters.
    blocks = (torch.einsum(expr, j1, j2) for j1, j2 in zip(jacs1, jacs2))
    return sum(blocks)
def set_random_seeds(seed):
    """Seed every RNG this script relies on (python, numpy, torch CPU + CUDA).

    The CUDA calls are safe no-ops on CPU-only machines.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def get_dataset(data_path, normalize=True):
    """Load CIFAR-10 train/test splits without augmentation.

    Args:
        data_path: directory CIFAR-10 is stored in (downloaded if missing).
        normalize: when True, apply per-channel mean/std normalization.

    Returns:
        (channel, im_size, num_classes, mean, std, dst_train, dst_test).
    """
    channel = 3
    im_size = (32, 32)
    num_classes = 10
    mean = [0.4914008, 0.482159, 0.44653094]
    std = np.array([0.24703224, 0.24348514, 0.26158786])

    steps = [transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize(mean=mean, std=std))
    transform = transforms.Compose(steps)

    # No augmentation on either split.
    dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
    dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
    return channel, im_size, num_classes, mean, std, dst_train, dst_test
def select_images(args, dst_train, num_classes, half_seed=None):
    """Organize the real dataset into flat tensors plus per-class indices.

    Args:
        args: namespace providing ``device`` (images are moved there).
        dst_train: indexable dataset yielding (image_tensor, int_label) pairs.
        num_classes: number of distinct labels.
        half_seed: unused; kept for interface compatibility with callers.

    Returns:
        images_all: (N, C, H, W) tensor on ``args.device``.
        labels_all: (N,) long tensor of labels (CPU).
        indices_class: list of per-class index lists into ``images_all``.
    """
    images = []
    labels = []
    indices_class = [[] for _ in range(num_classes)]
    # Single pass over the dataset. The previous version indexed dst_train a
    # second time just to read labels, which re-decodes and re-transforms
    # every image — twice the work for the same result.
    for i in range(len(dst_train)):
        img, lab = dst_train[i]
        images.append(torch.unsqueeze(img, dim=0))
        labels.append(lab)
        indices_class[lab].append(i)
    images_all = torch.cat(images, dim=0).to(args.device)
    labels_all = torch.tensor(labels, dtype=torch.long)
    return images_all, labels_all, indices_class
# CUDA_VISIBLE_DEVICES=1
def main():
parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--save_name', type=str, default='', help='additional surffix')
parser.add_argument('--model', type=str, default='AlexNet', help='model')
# DC parameters
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--Iteration', type=int, default=5000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=4e-2, help='learning rate for updating synthetic images')
parser.add_argument('--batch_real', type=int, default=128, help='batch size for real data')
parser.add_argument('--batch_syn', type=int, default=32, help='batch size for real data')
parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--decay_step', type=int, default=1)
parser.add_argument('--half_seed', type=int, default=None)
args = parser.parse_args()
eval_it_pool = np.arange(0, args.Iteration+1, args.Iteration // args.decay_step).tolist()[1:]
print(eval_it_pool)
args.clean = True
channel, im_size, num_classes, mean, std, dst_train, dst_test = get_dataset(DATA_PATH)
args.clean = False
images_all, labels_all, indices_class = select_images(args, dst_train, num_classes, half_seed=args.half_seed)
org_testloader = torch.utils.data.DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=2)
for c in range(num_classes):
print('class c = %d: %d real images'%(c, len(indices_class[c])))
for ch in range(channel):
print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))
''' initialize the synthetic data '''
image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float,
requires_grad=True, device=args.device)
label_syn = torch.arange(num_classes, device=args.device, dtype=torch.long).reshape(-1, 1).repeat(1, args.ipc).reshape(-1)
if args.init == 'real':
for c in range(num_classes):
np.random.seed(42)
idx_shuffle = np.random.permutation(indices_class[c])[:args.ipc]
image_syn.data[c*args.ipc:(c+1)*args.ipc] = images_all[idx_shuffle].detach().data
''' training '''
optimizer_img = torch.optim.Adam([image_syn,], lr=args.lr_img) # optimizer_img for synthetic data
optimizer_img.zero_grad()
# tqdm_range = trange(1, args.Iteration+1, desc='Loss', leave=True)
criterion = torch.nn.MSELoss()
for it in range(1, args.Iteration+1):
''' Evaluate synthetic data '''
''' Train synthetic data '''
net = AlexNet().to(args.device)
# net = ConvNetNTK(channel=channel, num_classes=num_classes,
# net_width=128, net_depth=3, net_act='relu',
# net_norm='none', net_pooling='avgpooling', im_size=im_size).to(args.device)
net.train()
# def reduce_logits(x):
# out = net(x)
# return torch.sum(out, dim=1) / (num_classes ** (1/2))
''' update synthetic data '''
batch_real_id = np.random.choice(len(images_all), args.batch_real, replace=False)
images_real_batch = images_all[batch_real_id]
label_real_batch = F.one_hot(labels_all[batch_real_id], num_classes=num_classes).to(args.device).float()
batch_syn_id = np.random.choice(len(image_syn), args.batch_syn, replace=False)
images_syn_batch = image_syn[batch_syn_id]
label_syn_batch = F.one_hot(label_syn[batch_syn_id], num_classes=num_classes).to(args.device).float()
if CENTER:
label_syn_batch -= 1/ num_classes
label_real_batch -= 1/ num_classes
fnet, params = make_functional(net)
def fnet_single(params, x):
return fnet(params, x.unsqueeze(0)).squeeze(0)
K_ss = empirical_ntk(fnet_single, params, images_syn_batch, images_syn_batch, 'trace')
K_ts = empirical_ntk(fnet_single, params, images_real_batch, images_syn_batch, 'trace')
K_ss_reg = (K_ss + 1e-6 * torch.trace(K_ss) * torch.eye(K_ss.shape[0], device=args.device) / K_ss.shape[0])
solve = torch.linalg.solve(K_ss_reg, label_syn_batch)
# solve, _ = torch.solve(label_syn_batch, K_ss_reg)
pred = torch.mm(K_ts, solve).to(args.device)
loss = criterion(pred, label_real_batch)
acc = torch.mean((torch.argmax(pred, dim=1) == torch.argmax(label_real_batch, dim=1)).float())
optimizer_img.zero_grad()
loss.backward()
optimizer_img.step()