from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import *
import numpy as np
import math
import copy
import random
torch.backends.cudnn.deterministic = True
width = 1024

class LINEAR(nn.Module):
    """Two-layer fully connected network: 784 -> width (ReLU) -> 1."""
    def __init__(self):
        super(LINEAR, self).__init__()
        self.weight1 = torch.nn.Parameter(torch.ones(784, width))
        self.bias1 = torch.nn.Parameter(torch.ones(1, width))
        self.weight2 = torch.nn.Parameter(torch.ones(width, 1))
        self.bias2 = torch.nn.Parameter(torch.ones(1, 1))
        # Scaled Gaussian initialization: weights ~ N(0, 1/fan_in), biases ~ N(0, 1).
        torch.nn.init.normal_(self.weight1, 0, 1 / math.sqrt(784))
        torch.nn.init.normal_(self.bias1, 0, 1)
        torch.nn.init.normal_(self.weight2, 0, 1 / math.sqrt(width))
        torch.nn.init.normal_(self.bias2, 0, 1)

    def forward(self, x):
        x = x.view(-1, 784)                           # flatten 28x28 images
        x = torch.mm(x, self.weight1) + self.bias1    # hidden layer
        x = F.relu(x)
        x = torch.mm(x, self.weight2)                 # output layer (bias2 is defined but not added here)
        return x
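# Quick shape check (a minimal sketch, not part of the training script): for a batch
# of 64 MNIST images, model(data) has shape [64, 1], while the integer targets coming
# out of the DataLoader have shape [64]; the two are only broadcast-compatible, not equal.
#   model = LINEAR()
#   out = model(torch.zeros(64, 1, 28, 28))   # out.shape == torch.Size([64, 1])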
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    accuracy = 0  # accumulated per-batch accuracy (local to this function)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target.float())  # used only for logging
        # Backpropagate a hand-computed gradient instead of calling loss.backward();
        # (output - target.float()) / args.batch_size is intended as the gradient of
        # (1/(2N)) * sum((output - target)^2) with respect to output.
        output.backward((output - target.float()) / args.batch_size)
        optimizer.step()
        pred = (output > 0.5).long()
        correct = pred.eq(target.view_as(pred)).sum().item()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f},\tAccuracy: {}/{} ({:.0f}%)'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), correct, args.batch_size,
                100. * correct / args.batch_size))
        accuracy += correct / args.batch_size
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=10, metavar='LR',
                        help='learning rate (default: 10)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--ns', type=float, default=0.1, metavar='ns',
                        help='noise rate (default: 0.1)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # Keep only the digits 0 and 1, so the task is binary classification with labels {0, 1}.
    idx = (mnist_trainset.targets == 0) | (mnist_trainset.targets == 1)
    mnist_trainset.targets = mnist_trainset.targets[idx]
    mnist_trainset.data = mnist_trainset.data[idx]

    train_loader = torch.utils.data.DataLoader(
        mnist_trainset,
        batch_size=args.batch_size, shuffle=True, **kwargs)

    model = LINEAR().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)

if __name__ == '__main__':
    main()
I tried a simple 2-layer neural network on the 2-class MNIST subset (digits 0 and 1), but it doesn't converge. Can anyone tell me why?
By the way, when I use the one-hot version of the labels, it converges.
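Roughly, by the one-hot version I mean replacing the single output unit with two outputs and one-hot targets, so output and target have the same shape. This is only a minimal sketch of the changed parts, not the exact code I ran:

import torch
import torch.nn.functional as F

# Illustrative sizes; weight shapes mirror the original model but with 2 output units.
batch, width = 64, 1024
weight1 = torch.randn(784, width) / 784 ** 0.5
weight2 = torch.randn(width, 2) / width ** 0.5
data = torch.randn(batch, 784)                        # stand-in for a batch of flattened images
target = torch.randint(0, 2, (batch,))                # labels in {0, 1}
onehot = F.one_hot(target, num_classes=2).float()     # shape [batch, 2]
output = F.relu(data @ weight1) @ weight2             # shape [batch, 2]
loss = F.mse_loss(output, onehot)                     # shapes match, no broadcasting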