I’ve adapted the code I’m using for a different project to make it runnable with the Fashion-MNIST dataset.
The idea of the code is extremely simple:
- Use a conv net to predict the rotation angle of images from the Fashion-MNIST dataset (publicly available and downloaded through the PyTorch API).
The problem is also very straightforward: the network ends up predicting the same label for every input!
I have spent several days on this already and thought it was time to get some help or opinions.
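Concretely, each image gets a random label k in {0, 1, 2, 3} and is rotated by k * 90 degrees before being fed to the network; roughly this (a simplified sketch of the idea, not the actual code, which is below in full):

import torch
import torchvision.transforms.v2 as transforms

image = torch.rand(1, 28, 28)                          # stand-in for a Fashion-MNIST image
k = int(torch.randint(4, size=(1,)).item())            # class label in {0, 1, 2, 3}
rotated = transforms.functional.rotate(image, k * 90)  # rotate by 0/90/180/270 degrees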
This is the code:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
# %%
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
# %%
class SpotRotBackbone(nn.Module):
    """
    Backbone (first part) of the rotation-spotting network.
    It is just a stack of convolutions.
    Important: the next layer expects a flattened tensor.

    Args: None
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        return x
# %%
class SpotRotHead(nn.Module):
    """
    Classification head for the image rotation angle.

    Args:
        input_size: size of the flattened feature vector coming from the backbone
    """

    def __init__(self, input_size: int):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(32, 4)
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        # x = self.softmax(x)
        return x
# %%
class NeuralNetwork(nn.Module):
    def __init__(self, input_size: int):
        super().__init__()
        self.backbone = SpotRotBackbone()
        # Dummy forward pass to find out how many features the backbone produces.
        output_dimensions = self.backbone.forward(
            torch.zeros(1, 1, input_size, input_size)
        ).shape
        self.head = SpotRotHead(output_dimensions[-1])

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
# from torchinfo import summary
# summary(NeuralNetwork(28), (1, 1, 28, 28))
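# For input_size=28 the backbone output flattens to 64 * 14 * 14 = 12544 features,
# which is what the head's first Linear layer receives.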
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(device)
model = NeuralNetwork(28).to(device)
# This is a little hack (hopefully correct) to generate my own labels from random rotations of the images!
rand_rots = []
def myTransform(image):
    image = ToTensor()(image)
    rand_rot = int(torch.randint(4, size=(1,)).item())
    angle = rand_rot * 90
    image = transforms.functional.rotate(image, angle)
    rand_rots.append(rand_rot)
    return image
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    # transform=ToTensor(),
    transform=myTransform,
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    # transform=ToTensor(),
    transform=myTransform,
)
train_dataloader = DataLoader(training_data, batch_size=16)
test_dataloader = DataLoader(test_data, batch_size=16)
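# %%
# A quick way to eyeball whether the rotations recorded in rand_rots line up with
# the images in a batch. This is just a rough sketch: it assumes matplotlib is
# available and relies on the default single-process, non-shuffled DataLoader so
# that samples arrive in the same order as the rand_rots entries.
import matplotlib.pyplot as plt

rand_rots.clear()
X_check, _ = next(iter(train_dataloader))  # fetching one batch runs myTransform and fills rand_rots
fig, axes = plt.subplots(1, 4)
for i, ax in enumerate(axes):
    ax.imshow(X_check[i, 0], cmap="gray")     # i-th image in the batch (single channel)
    ax.set_title(f"{rand_rots[i] * 90} deg")  # rotation recorded for that image
    ax.axis("off")
plt.show()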
def train_loop(dataloader, model: NeuralNetwork, loss_fn, optimizer):
    size = len(dataloader.dataset)  # type: ignore
    # Set the model to training mode - important for batch normalization and dropout layers
    model.train()
    rand_rots.clear()
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        # Ignore the dataset's own labels and use the rotations recorded by myTransform
        # for this batch (this relies on the loader not shuffling the samples).
        y = rand_rots[batch * 16 : batch * 16 + len(X)]
        y = torch.Tensor(y).to(torch.int64).to(device)
        assert isinstance(y, torch.Tensor)
        assert isinstance(X, torch.Tensor)

        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model: NeuralNetwork, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    y_max = None
    y = None

    with torch.no_grad():
        rand_rots.clear()
        for batch, (X, y) in enumerate(dataloader):
            X = X.to(device)
            y = rand_rots[batch * 16 : batch * 16 + len(X)]
            y = torch.Tensor(y).to(torch.int64).to(device)
            pred = model(X)
            y_max = pred.argmax(1)
            print(y, pred, y_max)
            assert isinstance(pred, torch.Tensor)
            test_loss += loss_fn(pred, y).item()
            correct += (y_max == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n Last y: {y} \n Last pred: {y_max} \n"
    )
# %%
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
The last predictions all look like this:
Accuracy: 25.4%, Avg loss: 1.386237
Last y: tensor([3, 1, 1, 2, 3, 3, 2, 0, 2, 2, 0, 3, 0, 0, 1, 0], device='mps:0')
Last pred: tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='mps:0')
Can you help me figure out what I am doing wrong? I suspect it is related either to the labels or to the conv net architecture.