Hello all,
I am a PyTorch beginner working on image classification with the MNIST dataset. The following code works, but I would like a code review to get other people's opinions, insights, and suggestions for improvement.
Apologies if this breaks any rules.
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.optim import SGD
from sklearn.metrics import accuracy_score
# Hyperparameters.  Mini-batches replace DataLoader's default batch_size=1,
# which made every epoch 60,000 single-sample optimizer steps.
BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
LEARNING_RATE = 0.01
# Momentum is added to help plain SGD make progress with ~64x fewer update
# steps per epoch than the batch_size=1 setup.  NOTE(review): tune as needed.
MOMENTUM = 0.9


## Model
class CNNClassificationModel(nn.Module):
    """Small CNN that maps 1x28x28 images (MNIST-sized) to 10 class logits.

    Three unpadded 5x5 convolutions shrink a 28x28 input to 16x16 with 16
    channels (16 * 16 * 16 = 4096 features).  A stack of fully connected
    layers then reduces these to 10 raw, unnormalized class scores.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 5)   # 28x28 -> 24x24
        self.conv2 = nn.Conv2d(4, 8, 5)   # 24x24 -> 20x20
        self.conv3 = nn.Conv2d(8, 16, 5)  # 20x20 -> 16x16
        # NOTE(review): fc1 alone holds ~8.4M weights; a much smaller head is
        # likely enough for MNIST.  Layer names are kept unchanged so existing
        # state_dicts still load.
        self.fc1 = nn.Linear(4096, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        # No softmax layer: nn.CrossEntropyLoss applies log-softmax itself and
        # expects raw logits.  argmax over logits equals argmax over
        # probabilities.

    def forward(self, image):
        """Return class logits, one row of 10 scores per input image.

        Args:
            image: tensor of images, (batch, 1, 28, 28) for MNIST.

        Returns:
            Tensor of shape (batch, 10) with unnormalized class scores.
        """
        x = self.relu(self.conv1(image))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = x.view(-1, 4096)  # 16 channels * 16 * 16 spatial positions
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.relu(self.fc5(x))
        # No ReLU here: logits must be able to go negative.  Clamping them at
        # zero would distort the cross-entropy gradients.
        return self.fc6(x)


## Training Loop
def train_one_epoch(model, loader, loss_function, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: module to train; assumed to already live on ``device``.
        loader: iterable of ``(images, labels)`` batches.
        loss_function: callable ``(logits, labels) -> scalar loss``.
        optimizer: optimizer built over ``model.parameters()``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Sample-weighted mean training loss for the epoch (0.0 if the loader
        yielded nothing).
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = loss_function(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples if n_samples else 0.0


## Evaluation
def evaluate(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Args:
        model: classifier returning (batch, n_classes) logits; assumed to
            already live on ``device``.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()  # no dropout/batch-norm today, but keeps evaluation correct if added
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # argmax over the class dimension, one prediction per sample;
            # a bare argmax() would flatten the whole batch.
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("loader yielded no samples to evaluate")
    return correct / total


# Guarded so importing this file does not download data or train.  In a
# notebook __name__ is "__main__", so the cell still runs and its names
# (model, mnist_train, accuracy, ...) remain globals.
if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    mnist_train = torchvision.datasets.MNIST('/content/', train=True, download=True, transform=ToTensor())
    mnist_test = torchvision.datasets.MNIST('/content/', train=False, download=True, transform=ToTensor())
    # Shuffle training data each epoch; test order does not matter.
    mnist_train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
    mnist_test_loader = DataLoader(mnist_test, batch_size=TEST_BATCH_SIZE)

    # The model must exist before the optimizer captures its parameters.  It is
    # moved to the device once, before the optimizer is constructed.
    model = CNNClassificationModel().to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    n_epochs = 2
    for epoch in range(n_epochs):
        mean_loss = train_one_epoch(model, mnist_train_loader, loss_function, optimizer, device)
        print(f'Epoch {epoch + 1}/{n_epochs}: mean training loss {mean_loss:.4f}')

    # Evaluate on the same device used for training; no need to move to CPU.
    accuracy = evaluate(model, mnist_test_loader, device)
    print(f'The model\'s accuracy is { accuracy * 100 } % ')