Hi,
Thanks a lot for your reply. Please find my code snippet below.
import torch
from networks.base_cnn import Net
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import random
from utils.distances import bhattacharyya_coefficient
from utils.nn_utils import softmax
import numpy as np
from utils.distances import color_similarity as color_similarity
import queue
import utils.nn_utils as utils
import resource
def _collect_reference_images(loader, num_classes=10, per_class=100):
    """Gather ``per_class`` reference images for each class id in ``range(num_classes)``.

    The loader is iterated from the start once per class, stopping as soon as
    enough images of that class have been seen.

    Returns:
        A list indexed by class id; each entry is a list of image tensors.

    Raises:
        ValueError: if the loader is exhausted before a class reaches
            ``per_class`` images. Skipping such a class silently would shift
            every later class id down by one in the returned list.
    """
    images = []
    for class_id in range(num_classes):
        class_images = []
        for inputs, labels in loader:
            for position in range(len(labels)):
                if labels[position] == class_id:
                    class_images.append(inputs[position])
                    if len(class_images) == per_class:
                        break
            if len(class_images) == per_class:
                break
        if len(class_images) < per_class:
            raise ValueError(
                'only %d reference images found for class %d, need %d'
                % (len(class_images), class_id, per_class)
            )
        images.append(class_images)
    return images


def _fine_tune(model, criterion, optimizer, device, samples, target, step_lr):
    """Take one SGD step on ``samples``, all labelled ``target``, at learning rate ``step_lr``."""
    batch = torch.zeros((len(samples), 3, 32, 32))
    for slot, sample in enumerate(samples):
        batch[slot] = sample
    targets = torch.tensor([target] * len(samples), dtype=torch.long)
    for group in optimizer.param_groups:
        group['lr'] = step_lr
    optimizer.zero_grad()
    # forward + backward + optimize
    loss = criterion(model(batch.to(device)), targets.to(device))
    loss.backward()
    optimizer.step()


def _print_accuracy(correct, total):
    """Print the running tallies; the accuracy is NaN when nothing has been counted yet."""
    print('Number of correct classfications : ', correct)
    print('Total number of classifications : ', total)
    accuracy = correct * 100 / total if total else float('nan')
    print('Accuracy : ', accuracy)


def evaluate():
    """Stream CIFAR-10 through a pretrained model while fine-tuning it from a bounded buffer.

    For every buffer capacity in ``list_max_sizes`` a fresh model is loaded.
    Each streamed sample is classified, then appended to the buffer together
    with its logits and one-hot label. Once the buffer is full, a random class
    ``y`` is drawn and every buffered entry is scored with
    ``bhattacharyya_coefficient`` between the one-hot vector of ``y`` and the
    softmax of the stored logits:

    * entries scoring above ``epsilon`` (at most 4, highest scores kept) are
      used for one SGD step with label ``y`` at ``lr / 10``;
    * if fewer than 4 such entries exist, entries scoring in
      ``[delta, epsilon]`` top the batch up to 4 and are used for a second
      step at ``lr / 2``;
    * entries used for training are removed from the buffer; if none were
      used, the oldest entry is dropped instead.

    Tallies are printed after the stream ends and again after the samples
    still in the buffer have been re-classified.

    NOTE(review): ``correct`` is incremented for every streamed sample and
    again for the samples left in the buffer, while ``check_counter`` counts
    each sample once, so buffered samples appear to be counted twice in the
    numerator. Left unchanged; confirm the intended metric.
    """
    lr = 0.0001
    list_max_sizes = [10, 100, 1000, 2000, 3000, 5000, 10000]
    data_path = 'data1'
    model_path = 'data/alexnet_model_epoch_50_lr_001_2'
    batch_cap = 4  # maximum number of buffered samples per fine-tuning step
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # The train=False split is only used as a pool of per-class reference images.
    reference_set = torchvision.datasets.CIFAR10(
        root=data_path, train=False, download=False, transform=transform)
    reference_loader = torch.utils.data.DataLoader(
        reference_set, batch_size=1, shuffle=True, num_workers=0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): only torch is seeded; the `random` module used for the
    # class / reference-image draws below is not, so runs are not repeatable.
    seed = 0
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # base images for reference
    images = _collect_reference_images(reference_loader)

    epsilon, delta = 0.8, 0.5  # compare Y
    z_e, z_d = 10, 30  # compare Z  (NOTE(review): not read anywhere in this function)

    for list_max_size in list_max_sizes:
        model = utils.retrieve_model_AlexNet(model_path).to(device)
        model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr)
        correct = 0
        q, data_list = [], []  # in the long run, remove two lists, use only one data source
        new_counter, check_counter = 0, 0

        # The train=True split is the stream that gets classified and buffered.
        stream_set = torchvision.datasets.CIFAR10(
            root=data_path, train=True, download=False, transform=transform)
        stream_loader = torch.utils.data.DataLoader(
            stream_set, batch_size=1, shuffle=False, num_workers=0)

        for i, data in enumerate(stream_loader, 0):
            inputs, labels = data  # the loader uses batch_size=1: one image, one label
            inputs, labels = inputs.to(device), labels.to(device)
            # Prediction only: no gradient is taken from this pass.
            with torch.no_grad():
                outputs = model(inputs)
            numpy_data = outputs.cpu().numpy()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            print(i)

            for index_in_batch in range(numpy_data.shape[0]):
                if len(q) == list_max_size:  # If the list is full, remove one or more
                    y = random.randint(0, 9)  # generate a random number
                    # get a reference image for the selected class
                    reference_image = images[y][random.randint(0, 99)]
                    reference_np = reference_image.data.numpy()
                    label_array = np.zeros(10)
                    label_array[y] = 1

                    # NOTE(review): max_coeff / min_chi (and the colour
                    # similarity feeding min_chi) are tracked but never read.
                    max_coeff = -1
                    min_chi = 500

                    # Candidates are (score, buffer index) pairs. Keeping the
                    # tensors out of the tuples means equal scores fall back
                    # to comparing integers rather than tensors, which raises.
                    strong, weak = [], []
                    for queue_index in range(list_max_size):
                        dist = bhattacharyya_coefficient(
                            h1=label_array, h2=softmax(q[queue_index][:10]))
                        ex_d, ex_l = data_list[queue_index]
                        chi_dist = color_similarity(reference_np, np.squeeze(ex_d.numpy()))
                        if dist > epsilon:
                            strong.append((dist, queue_index))
                        elif dist >= delta:
                            weak.append((dist, queue_index))
                        if dist > max_coeff:
                            max_coeff = dist
                        if min_chi > chi_dist:
                            min_chi = chi_dist

                    strong.sort()
                    weak.sort()
                    # A full batch of strong matches suppresses the weak pass.
                    done = len(strong) >= batch_cap
                    strong = strong[-batch_cap:]  # keep the highest scores
                    print('Queue size : ', len(strong))

                    temp_indices = []
                    if strong:
                        check_counter += len(strong)
                        _fine_tune(
                            model, criterion, optimizer, device,
                            [data_list[idx][0] for _, idx in strong], y, lr / 10)
                        temp_indices.extend(idx for _, idx in strong)

                    if not done:
                        room = batch_cap - len(strong)
                        if room > 0 and weak:
                            weak = weak[-room:]  # keep the highest scores
                            check_counter += len(weak)
                            _fine_tune(
                                model, criterion, optimizer, device,
                                [data_list[idx][0] for _, idx in weak], y, lr / 2)
                            temp_indices.extend(idx for _, idx in weak)

                    if not temp_indices:
                        # Nothing was trained on: evict the oldest entry.
                        del q[0]
                        del data_list[0]
                        check_counter += 1
                    else:
                        # Delete from the back so earlier indices stay valid.
                        for index in sorted(temp_indices, reverse=True):
                            del q[index]
                            del data_list[index]

                # Add the current element to the list
                conf_values = numpy_data[index_in_batch]
                label_array_1 = np.zeros(len(conf_values))
                label_array_1[labels[index_in_batch].item()] = 1
                predicted_actual = np.concatenate([conf_values, label_array_1])
                q.append(predicted_actual)
                data_list.append(data)

        _print_accuracy(correct, check_counter)

        # Re-classify whatever is still buffered with the final model.
        with torch.no_grad():
            for inputs, labels in data_list:
                # The buffer holds the loader's CPU tensors; move them to the
                # model's device before the forward pass and the comparison.
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                hits = (predicted == labels).sum().item()
                correct += hits
                new_counter += hits
                check_counter += 1

        _print_accuracy(correct, check_counter)
if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    # Raise the soft open-file limit, but never past the hard limit:
    # setrlimit raises ValueError when the soft value exceeds the hard one.
    _, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    wanted_limit = 2048 * 20
    if hard_limit != resource.RLIM_INFINITY:
        wanted_limit = min(wanted_limit, hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_limit, hard_limit))
    evaluate()
Thanks