I need help implementing a Gaussian model classifier on the MNIST dataset (http://www-inst.eecs.berkeley.edu/~cs70/sp15/notes/n21.pdf). So far I have calculated the mean and variance of each digit class from 0 to 9, but I am not sure what to do next.
Can anybody guide me or provide a solution? I want to calculate the covariance matrix along with the mean and variance, and then compute the MAP estimator.
Any help would be really great.
def online_mean_and_sd(trainLoader):
    """Compute the per-channel mean and standard deviation of a dataset in one pass.

    Accumulates, per channel, the running sum and sum of squares of every
    value seen, then applies Var[X] = E[X^2] - E^2[X] once at the end. This
    yields the true (population) std over all pixels of the dataset. Averaging
    each image's own std instead would ignore the variation between images.

    Args:
        trainLoader: iterable yielding ``(data, target)`` pairs. ``data`` is a
            tensor whose dim 0 is the batch and dim 1 the channel; any
            trailing dimensions are flattened. Presumably a
            ``torch.utils.data.DataLoader``.

    Returns:
        ``(mean, std)``: 1-D tensors with one entry per channel, in the
        input's floating dtype (or torch's default dtype for integer input).

    Raises:
        ValueError: if the loader yields no data.
    """
    total = 0.
    total_sq = 0.
    n_values = 0
    out_dtype = torch.get_default_dtype()
    for data, _ in trainLoader:
        batch_samples = data.size(0)
        if data.is_floating_point():
            out_dtype = data.dtype
        # Accumulate in float64: E[X^2] - E^2[X] suffers from cancellation
        # when the running sums grow large in float32.
        data = data.reshape(batch_samples, data.size(1), -1).to(torch.float64)
        total = total + data.sum(dim=(0, 2))
        total_sq = total_sq + data.pow(2).sum(dim=(0, 2))
        n_values += batch_samples * data.size(2)
    if n_values == 0:
        raise ValueError("online_mean_and_sd: loader yielded no data")
    mean = total / n_values
    # Clamp tiny negative values produced by floating-point rounding.
    var = (total_sq / n_values - mean.pow(2)).clamp(min=0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
# Build the full MNIST train/test datasets (downloaded on demand into the
# given root) and their shuffled DataLoaders.
# NOTE(review): `transform`, `args` and `kwargs` are defined outside this view.
trainset=datasets.MNIST(root='/home/jmandivarapu1/datasets', train=True,download=True, transform=transform)
testset = datasets.MNIST(root='/home/jmandivarapu1/datasets', train=False,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,shuffle=True,**kwargs)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,shuffle=True, **kwargs)
# Keep references to the original (unfiltered) labels, images and loaders.
# load_individual_class() later rebinds trainset/testset .targets and .data
# (and the module-level train_loader/test_loader), so these names keep
# pointing at the full dataset. They are references, not copies.
# NOTE(review): presumably RELOAD_DATASET() uses these to restore the full
# dataset before each filtering step — confirm.
global_trainset_train_labels=trainset.targets
global_trainset_train_data=trainset.data
global_testset_test_labels=testset.targets
global_testset_test_data=testset.data
global_train_loader=train_loader
global_test_loader=test_loader
######################################################################################################
# MODIFYING DATASET FOR INDIVIDUAL CLASSES
######################################################################################################
def _gather_indices(index_map, classes, fraction=1.0):
    """Concatenate ``index_map[c]`` for each ``c`` in ``classes``, keeping the
    first ``int(fraction * len(index_map[c]))`` indices of each class."""
    indices = []
    for cls in classes:
        class_indices = index_map[cls]
        indices.extend(class_indices[:int(fraction * len(class_indices))])
    return indices


def _binary_labels(targets, indices, positive_classes):
    """Return 1 for each index whose label in ``targets`` is a positive class, else 0."""
    # int() turns tensor scalars into plain ints so set membership works
    # (tensors hash by identity, not by value).
    positive = {int(c) for c in positive_classes}
    return [1 if int(targets[i]) in positive else 0 for i in indices]


def load_individual_class(postive_class, negative_classes, negative_fraction=1.0,
                          batch_size=32, high_batch_size=1000):
    """Restrict the global MNIST train/test sets to a binary positive-vs-negative task.

    Calls ``RELOAD_DATASET()`` first (defined elsewhere; presumably restores
    the full dataset — confirm). It then collects the sample indices of the
    requested classes from ``train_Indexs``/``test_Indexs``. These are
    presumably mappings from class label to a list of sample indices (TODO
    confirm). Each selected sample is relabeled 1 if its original label is in
    ``postive_class`` and 0 otherwise. Finally, ``trainset``/``testset``
    ``.data`` and ``.targets`` are replaced by the filtered subset.

    Args:
        postive_class: class labels that become label 1 (all of their samples are kept).
        negative_classes: class labels that become label 0.
        negative_fraction: fraction in [0, 1] of each negative class's index
            list to keep, taken from the front. 1.0 (default) keeps all.
        batch_size: batch size of the returned train/test loaders.
        high_batch_size: batch size of the returned large-batch train loader.

    Returns:
        ``(train_loader, test_loader, train_loader_highbatch)``. The
        module-level ``train_loader`` and ``test_loader`` are rebound to the
        first two as well.

    Raises:
        ValueError: if ``negative_fraction`` is outside [0, 1].
    """
    global train_loader, test_loader, train_Indexs, test_Indexs
    if not 0.0 <= negative_fraction <= 1.0:
        raise ValueError("negative_fraction must be in [0, 1], got %r" % (negative_fraction,))
    RELOAD_DATASET()

    index_train = (_gather_indices(train_Indexs, postive_class)
                   + _gather_indices(train_Indexs, negative_classes, negative_fraction))
    index_test = (_gather_indices(test_Indexs, postive_class)
                  + _gather_indices(test_Indexs, negative_classes, negative_fraction))

    # Relabel from the original labels before .targets is overwritten below.
    # .targets/.data replace the deprecated train_labels/test_labels/
    # train_data/test_data aliases (same objects in current torchvision).
    trainset.targets = _binary_labels(trainset.targets, index_train, postive_class)
    trainset.data = trainset.data[index_train]
    testset.targets = _binary_labels(testset.targets, index_test, postive_class)
    testset.data = testset.data[index_test]

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)
    train_loader_highbatch = torch.utils.data.DataLoader(trainset, batch_size=high_batch_size, shuffle=True)
    return train_loader, test_loader, train_loader_highbatch
mnistSkills=[0,1,2,3,4,5,6,7,8,9]
# One entry per class: [{'<class>': {'mean': ..., 'std': ..., 'var': ...}}]
mnistSkillsMeanVar=[]
#####################################################################################################
# MNIST TRAINING
######################################################################################################
for mnsit_class in mnistSkills:
    task_samples=[]
    # Iterate the class labels themselves (not their positions) so the class
    # loaded below always matches the entry listed in mnistSkills.
    train_skills=mnsit_class
    print("Skills are ",train_skills)
    # NOTE(review): model/optimizer are not used in the visible lines; the
    # loop body presumably continues beyond this view — confirm.
    model = models.Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    Train_loader,Test_loader,train_loader_highbatch=load_individual_class([mnsit_class],[])
    mean,std= online_mean_and_sd(train_loader_highbatch)
    # Was `SD`, an undefined name (NameError); the function returns `std`.
    print('SKILL',mnsit_class,mean,std)
    # 'var' previously held the standard deviation; store the real variance
    # and keep the std under its own key.
    mnistSkillsMeanVar.append([{str(mnsit_class):{'mean':mean,'std':std,'var':std ** 2}}])