I used the following code to train an MLP on the MNIST dataset. When neither L1 nor L2 regularization is used, the test accuracy reaches 94%. When L2 is used without L1, the accuracy reaches 96%. Using L1, however, drops the accuracy straight down to 11%. Is my implementation wrong? Or is L1 just like that, because there are so many parameters that the sum of their absolute values is large enough to dominate the whole training process?
import torch
import torchvision
from visdom import Visdom
from sklearn.model_selection import KFold
# Training hyper-parameters.
maxepoch = 20          # number of passes over the training folds
k_folds = 5            # number of K-fold splits of the training set
batch_size = 50        # mini-batch size for every DataLoader
learning_rate = 1e-2   # SGD step size
# L1 coefficient. The penalty used in the training loop is
# lan_l1 * sum(|p|) over *all* parameters, so every parameter receives a
# gradient of constant magnitude lan_l1 regardless of how small it already
# is. With the previous value of 0.01 and learning_rate = 1e-2, each SGD step
# pulled every weight 1e-4 toward zero, which overwhelms the data term and
# collapses the network to chance accuracy. A much smaller coefficient keeps
# the penalty a mild regulariser; set it to 0 to disable L1 entirely.
lan_l1 = 1e-5
lan_l2 = 0             # L2 coefficient, passed to SGD as weight_decay
# Fall back to the CPU when CUDA is unavailable instead of failing later
# when the model and batches are moved to the device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Both splits go through the same preprocessing: convert the image to a
# tensor, then normalise it with mean 0.1307 and standard deviation 0.3081.
_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split; downloaded into ../data when it is not already there.
train_data = torchvision.datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=_preprocess,
)

# Test split; read from the same directory (no download requested here).
test_data = torchvision.datasets.MNIST(
    '../data',
    train=False,
    transform=_preprocess,
)
class myMLP(torch.nn.Module):
    """Three-layer fully connected classifier: 784 -> 200 -> 200 -> 10.

    Each hidden linear layer is followed by dropout and an in-place
    LeakyReLU. The final layer has no activation, so ``forward`` returns
    raw per-class scores.
    """

    def __init__(self):
        super().__init__()
        # The layers are created in the same order in which they are applied.
        stack = [
            torch.nn.Linear(784, 200),
            torch.nn.Dropout(0.3),              # drop 30%
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(200, 200),
            torch.nn.Dropout(0.4),              # drop 40%
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(200, 10),
        ]
        self.model = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` (last dimension of size 784) through the layer stack."""
        return self.model(x)
# Network, optimiser and loss. L2 regularisation is applied through the
# optimiser's weight_decay; the L1 term is added by hand in the training loop.
myNet = myMLP()
myNet.to(device)
optimizer = torch.optim.SGD(
    myNet.parameters(),
    lr=learning_rate,
    weight_decay=lan_l2,
)
loss_function = torch.nn.CrossEntropyLoss()
loss_function.to(device)

# Visdom windows for the live plots, each seeded with a single (0, 0) point.
viz = Visdom()
viz.line([0.], [0.], win='train_loss', opts={'title': 'Train Loss'})
viz.line([0.], [0.], win='val', opts={'title': 'Validation Accuracy'})
global_step = 0

# Pre-compute the K-fold index splits once; the epoch loop picks one
# (train, validation) pair per epoch.
kfold = KFold(n_splits=k_folds, shuffle=True)
fold_splits = list(kfold.split(train_data))
train_ids_set = [train_part for train_part, _ in fold_splits]
val_ids_set = [val_part for _, val_part in fold_splits]
# Train for `maxepoch` epochs. Each epoch trains on one K-fold training
# subset and then evaluates on the matching held-out subset.
#
# NOTE(review): the folds rotate (epoch % k_folds) while one and the same
# model keeps training, so every validation fold has been trained on during
# other epochs; the reported validation accuracy is therefore optimistic.
# NOTE(review): the pasted source lost its indentation. The loss plot is
# assumed to be updated on every optimisation step (matching the name
# `global_step`), and the image/prediction preview is issued once after the
# validation loop, which leaves the same final window contents as issuing it
# per batch. Confirm against the original script.
for epoch in range(maxepoch):
    fold = epoch % k_folds
    train_ids = train_ids_set[fold]
    val_ids = val_ids_set[fold]
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, sampler=train_subsampler)
    val_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, sampler=val_subsampler)

    # ---- training ---------------------------------------------------------
    myNet.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.to(device)
        logits = myNet(data)
        loss = loss_function(logits, target)
        if lan_l1:
            # L1 penalty: lan_l1 * sum of |p| over every parameter tensor.
            # Skipped when the coefficient is zero to avoid needless work.
            loss_l1 = sum(parm.abs().sum() for parm in myNet.parameters())
            loss = loss + lan_l1*loss_l1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1) % 60 == 0:
            # The criterion already averages over the batch, so the value is
            # reported as-is (it was previously divided by batch_size a
            # second time). It matches what is plotted below.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage Loss: {:.6f}'.format(
                epoch,
                (batch_idx+1)*len(data),
                len(train_ids),
                100*(batch_idx+1)/len(train_loader),
                loss.item()))
        global_step += 1
        viz.line([loss.item()], [global_step], win='train_loss', update='append')

    # ---- validation -------------------------------------------------------
    myNet.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for data, target in val_loader:
            data = data.view(-1, 28*28)
            data, target = data.to(device), target.to(device)
            logits = myNet(data)
            # Weight each batch mean by its size so that dividing by the
            # number of samples below yields a true per-sample average.
            val_loss += loss_function(logits, target).item()*target.size(0)
            pred = logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    # Preview of the last validation batch and its predictions.
    viz.images(data.view(-1, 1, 28, 28).clamp(0, 1), win='pics', opts=dict(title='Handwirtting'))
    viz.text(str(pred), win='pred', opts=dict(title='Predicted'))

    val_total = len(val_ids)
    val_loss /= val_total
    val_accuracy = correct/val_total
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, val_total, 100*val_accuracy))
    viz.line([val_accuracy], [epoch], win='val', update='append')
# Final evaluation of the trained network on the held-out test split.
myNet.eval()
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
test_loss = 0.0
correct = 0
with torch.no_grad():  # evaluation only: skip building the autograd graph
    for data, target in test_loader:
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.to(device)
        logits = myNet(data)
        # The criterion returns the batch mean; weight it by the batch size
        # so that dividing by the dataset size gives the per-sample average
        # (previously a sum of batch means was divided by the sample count).
        test_loss += loss_function(logits, target).item()*target.size(0)
        pred = logits.argmax(dim=1)
        correct += pred.eq(target).sum().item()
# Preview of the last test batch and its predictions.
# NOTE(review): the pasted source lost its indentation; issuing the preview
# once after the loop leaves the same final window contents as per batch.
viz.images(data.view(-1, 1, 28, 28).clamp(0, 1), win='pics', opts=dict(title='Handwirtting'))
viz.text(str(pred), win='pred', opts=dict(title='Predicted'))

test_total = len(test_loader.dataset)
test_loss /= test_total
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, test_total, 100*correct/test_total))
Thanks.