In case the code of the network is needed, here it is:
#!/usr/bin/env python
# coding: utf-8
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import random
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime
import os
from matplotlib.pyplot import figure
import shutil
figure(figsize=(12, 10), dpi=120)
plt.rcParams['image.cmap'] = 'gray'
batch_size = 100
NUMBER_OF_CLASSES = 10
NUMBER_OF_CLUSTER_CENTROIDS = 32
NO_TRAINING = False
device = torch.device('mps')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
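# augment_data returns two views of each batch: a plain normalized copy and a
# randomly cropped / flipped copy, used later as the positive pair for the contrastive loss.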
def augment_data(data):
    transform = transforms.Compose([
        transforms.Normalize((0.5,), (0.5,))])
    transform_aug = transforms.Compose([
        transforms.RandomResizedCrop(size=28, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize((0.5,), (0.5,))])
    return transform(data), transform_aug(data)
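# The discriminator has a shared MLP trunk and three outputs:
#   - real_or_fake_layer: scalar logit for the hinge adversarial loss
#   - image_feature_layer: NUMBER_OF_CLUSTER_CENTROIDS-dim feature used by the info/contrastive losses
#   - centroids: a learned embedding of the NUMBER_OF_CLASSES one-hot codes (the cluster centroids)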
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
        )
        self.real_or_fake_layer = nn.Sequential(
            nn.Linear(128, 1),
        )
        self.image_feature_layer = nn.Sequential(
            nn.Linear(128, NUMBER_OF_CLUSTER_CENTROIDS),
            nn.Sigmoid()
        )
        self.centroids = nn.Linear(NUMBER_OF_CLASSES, NUMBER_OF_CLUSTER_CENTROIDS)

    def forward(self, input, eye):
        common_part = self.main(input)
        real_or_fake = self.real_or_fake_layer(common_part)
        image_feature = self.image_feature_layer(common_part)
        cluster_centroid = self.centroids(eye)
        return real_or_fake, image_feature, cluster_centroid
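# The generator takes a 138-dim input (128-dim Gaussian noise concatenated with a
# 10-dim one-hot class code) and outputs a 784-dim (28x28) image in [-1, 1] via Tanh.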
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(138, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, input):
        step = self.main(input)
        return step
def show_images(images, epoch, folder_path):
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    for index, image in enumerate(images):
        plt.subplot(sqrtn, sqrtn, index + 1)
        plt.imshow(image.reshape(28, 28))
    path = folder_path + str(epoch) + ".png"
    plt.savefig(path)
transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST('./MNIST_DATA/train/', train=True, download=True,transform=transform)
test_set = datasets.MNIST('./MNIST_DATA/test/', train=False, download=True,transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)
CrosEntr = nn.CrossEntropyLoss()
from scipy.optimize import linear_sum_assignment
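# Clustering metrics: get_stat builds the class-vs-cluster contingency table,
# get_match finds the best cluster-to-class assignment with the Hungarian algorithm
# (linear_sum_assignment), get_acc computes accuracy under that assignment, and
# get_nmi computes normalized mutual information from the same table.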
def get_stat(generated_label, num_of_gt_labels, gt_labels):
    c_stat = np.zeros([num_of_gt_labels, num_of_gt_labels])
    for i in range(len(gt_labels)):
        gt_idx = int(gt_labels[i])
        c_stat[gt_idx][generated_label[i]] += 1
    return c_stat

def get_match(stat):
    _, col_ind = linear_sum_assignment(stat.max() - stat)
    return col_ind

def get_acc(stat, col_ind, over):
    tot = 0
    for i in range(stat.shape[0]):
        tot += stat[i][col_ind[i]]
    return tot / (np.sum(stat) / over)

def get_nmi(stat):
    n, m = stat.shape
    pij = stat / np.sum(stat)
    pi = np.sum(pij, 1)
    pj = np.sum(pij, 0)
    enti = sum([-pi[i] * np.log2(pi[i] + 1e-6) for i in range(n)])
    entj = sum([-pj[i] * np.log2(pj[i] + 1e-6) for i in range(m)])
    mi = 0
    for i in range(n):
        for j in range(m):
            mi += pij[i][j] * (np.log2(pij[i][j] / (pi[i] * pj[j] + 1e-6) + 1e-6))
    return mi / max(enti, entj)
epochs=200
lr = 0.0005
TEMP = 0.1
HYPER_PARA_ADV = 0.5
HYPER_PARA_INFO_LOSS = 5
HYPER_PARA_CONTR = 0.2
HYPER_PARA_ENTROPY = 0.2
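# TEMP is the softmax temperature applied to the cosine-similarity logits; the
# HYPER_PARA_* constants weight the adversarial, info, contrastive and entropy loss terms.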
now = datetime.now()
print("now =", now)
dt_string = now.strftime("%Y_%m_%d_%H%M%S")
print("date and time =", dt_string)
# Create the results directory if it does not already exist
save_folder_path = "./results/" + dt_string + "/"
if not NO_TRAINING and not os.path.exists(save_folder_path):
    os.makedirs(save_folder_path[2:-1])
    shutil.copy2("Sandbox_CIFAR10.py", save_folder_path[2:-1])
# TRAINING
#############################################################################################
Gen = generator().to(device)
Dis = discriminator().to(device)
print(Gen)
print(Dis)
optimizerG = optim.Adam(Gen.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = optim.Adam(Dis.parameters(), lr=lr, betas=(0.5, 0.999))
best_acc = 0
best_nmi = 0
#############################################################################################
eye = torch.eye(NUMBER_OF_CLASSES).to(device)
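# eye is the 10x10 identity matrix of one-hot class codes; passing it through
# Dis.centroids yields one cluster-centroid embedding per class.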
for epoch in range(epochs):
    epoch += 1
    if NO_TRAINING:
        print("There won't be any training, go to the top cell and change NO_TRAINING to False")
        break
    for times, data in enumerate(train_loader):
        times += 1
        # 1. Prepare Inputs
        real_inputs, aug_real_inputs = augment_data(data[0])
        real_inputs, aug_real_inputs = real_inputs.to(device), aug_real_inputs.to(device)
        real_inputs, aug_real_inputs = real_inputs.view(-1, 784), aug_real_inputs.view(-1, 784)
        noise = torch.randn(real_inputs.shape[0], 128)
        noise = noise.to(device)
        rand_c = torch.zeros(real_inputs.shape[0], NUMBER_OF_CLASSES).to(device)
        rand_idx = [i for i in range(NUMBER_OF_CLASSES)]
        random.shuffle(rand_idx)
        for i, element in enumerate(rand_c):
            element[rand_idx[i % NUMBER_OF_CLASSES]] = 1
        noise = torch.cat((noise, rand_c), 1)
        fake_inputs = Gen(noise)
        # 2. Train Discriminator
        # 2.1 Pass Images
        optimizerD.zero_grad()
        real_outputs, real_img_feat, real_class_emb = Dis(real_inputs, eye)
        fake_outputs, fake_img_feat, fake_class_emb = Dis(fake_inputs, eye)
        # 2.2 Adversarial Loss (hinge loss)
        real_adv_loss = HYPER_PARA_ADV * (nn.ReLU()(1 - real_outputs).mean())
        fake_adv_loss_D = HYPER_PARA_ADV * (nn.ReLU()(1 + fake_outputs).mean())
        # 2.3 Info Loss: fake image features should predict the class code they were generated from
        f = F.normalize(fake_img_feat, p=2, dim=1)
        c = F.normalize(fake_class_emb, p=2, dim=1)
        class_dist = torch.cat([torch.matmul(f, c[i].unsqueeze(-1)) / TEMP for i in range(NUMBER_OF_CLASSES)], 1)
        info_loss = HYPER_PARA_INFO_LOSS * CrosEntr(class_dist, torch.argmax(rand_c, 1))
        # 2.4 Real Image Contrastive Loss (positives are the augmented views of the same image)
        labels = torch.cat([torch.arange(batch_size) for i in range(2)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        _, real_aug_img_feat, _ = Dis(aug_real_inputs, eye)
        feat = torch.cat([real_img_feat, real_aug_img_feat], 0)
        feat = F.normalize(feat, p=2, dim=1)
        similarity_matrix = torch.matmul(feat, feat.T)
        # note: detach() means no gradient flows back through the contrastive term
        similarity_matrix = similarity_matrix.detach().cpu()
        mask = torch.eye(feat.shape[0], dtype=torch.bool)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1).to(device)
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1).to(device)
        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
        logits = logits / TEMP
        contrastive_loss = HYPER_PARA_CONTR * CrosEntr(logits, labels)
        # 2.5 Entropy Regularizations: confident per-sample assignments, balanced cluster usage
        rr = F.normalize(real_img_feat, p=2, dim=1)
        class_dist_real = torch.cat([torch.matmul(rr, c[i].unsqueeze(-1)) / TEMP for i in range(NUMBER_OF_CLASSES)], 1)
        class_dist_real_sm = F.softmax(class_dist_real, 1)
        entropy_reg_1 = -1 * HYPER_PARA_ENTROPY * (class_dist_real_sm * torch.log(class_dist_real_sm)).sum(1).mean()
        entropy_reg_2 = 1 * HYPER_PARA_ENTROPY * (class_dist_real_sm.mean(0) * torch.log(class_dist_real_sm.mean(0))).sum()
        # 2.6 Backward Pass (zero_grad was called at the start of step 2.1)
        D_loss = real_adv_loss + fake_adv_loss_D + info_loss + contrastive_loss + entropy_reg_1 + entropy_reg_2
        D_loss.backward()
        optimizerD.step()
        # 3. Prepare fresh noise and class codes for the generator update
        noise = torch.randn(real_inputs.shape[0], 128)
        noise = noise.to(device)
        rand_c = torch.zeros(real_inputs.shape[0], NUMBER_OF_CLASSES).to(device)
        rand_idx = [i for i in range(NUMBER_OF_CLASSES)]
        random.shuffle(rand_idx)
        for i, element in enumerate(rand_c):
            element[rand_idx[i % NUMBER_OF_CLASSES]] = 1
        noise = torch.cat((noise, rand_c), 1)
        fake_inputs = Gen(noise)
        # 4. Train Generator
        optimizerG.zero_grad()
        fake_outputs, fake_img_feat, fake_class_emb_2 = Dis(fake_inputs, eye)
        # 4.1 Adversarial Loss
        fake_adv_loss_G = -HYPER_PARA_ADV * fake_outputs.mean()
        # 4.2 Info Loss
        f_2 = F.normalize(fake_img_feat, p=2, dim=1)
        c_2 = F.normalize(fake_class_emb_2, p=2, dim=1)
        class_dist_2 = torch.cat([torch.matmul(f_2, c_2[i].unsqueeze(-1)) / TEMP for i in range(NUMBER_OF_CLASSES)], 1)
        info_loss_2 = HYPER_PARA_INFO_LOSS * CrosEntr(class_dist_2, torch.argmax(rand_c, 1))
        G_loss = fake_adv_loss_G + info_loss_2
        G_loss.backward()
        optimizerG.step()
    # TEST PERFORMANCE (once per epoch)
    pred_c = []
    pred_c_cpu = []
    real_c = []
    with torch.no_grad():
        for image, label in test_loader:
            real_img_test = image.to(device).view(-1, 784)
            _, feat_test, class_emb_test = Dis(real_img_test, eye)
            f_test = F.normalize(feat_test, p=2, dim=1)
            c_test = F.normalize(class_emb_test, p=2, dim=1)
            class_dist_test = torch.cat([torch.from_numpy(np.dot(f_test.cpu(), c_test[i].cpu())).float().to(device).unsqueeze(-1) / TEMP
                                         for i in range(NUMBER_OF_CLASSES)], 1)
            # class_dist_candas = torch.cat([torch.matmul(f, c[i]).unsqueeze(-1)/TEMP for i in range(NUMBER_OF_CLASSES)], 1) ### WRONG: FOR SOME REASON IT CALCULATES WRONG
            pred_c += list(torch.argmax(class_dist_test, 1))
            pred_c_cpu += list(torch.argmax(class_dist_test.cpu(), 1))
            real_c += list(label.cpu().numpy())
    c_table = get_stat(pred_c, NUMBER_OF_CLASSES, real_c)
    idx_map = get_match(c_table)
    cur_acc = get_acc(c_table, idx_map, 1)
    cur_nmi = get_nmi(c_table[:10, :])
    if cur_acc > best_acc:
        best_acc = cur_acc
        best_nmi = cur_nmi
        torch.save(Gen, save_folder_path + 'Best_Generator.pth')
        torch.save(Dis, save_folder_path + 'Best_Discriminator.pth')
    with open(save_folder_path + "0_Test_Performance.txt", 'a') as f:
        f.write("Epoch: {}, Test Accuracy: {:.3f}, Test NMI: {:.3f}\n".format(epoch, cur_acc, cur_nmi))
    imgs_numpy = (fake_inputs.data.cpu().numpy() + 1.0) / 2.0
    print(imgs_numpy.size)
    print(imgs_numpy[:16].size)
    show_images(imgs_numpy[:16], epoch, save_folder_path)