I’ve written a snippet to classify Omniglot images. I calculate the training and validation losses in each epoch, where the latter is computed using images that were not seen by the network before. The two plots are as below:
Since the training loss decreases while the validation loss increases, I have concluded that my model overfits. I’ve tried several suggestions (e.g. here) to overcome this, including:
- Increasing the size of the training set.
- shuffling the data.
- Adding dropout layers (up to p=0.9).
- Using smaller model.
- Altering the architecture.
- Changing the learning rate.
- Reducing the batch size.
- Adding weight decay.
However, the validation loss still increases. I wonder if there are any other suggestions to improve this behavior or if this is not overfitting, but the problem is something else. Below is the snippet used in this question.
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
dim_out = 964
# -- embedding params
self.cn1 = nn.Conv2d(1, 16, 7)
self.cn2 = nn.Conv2d(16, 32, 4)
self.cn3 = nn.Conv2d(32, 64, 3)
self.pool = nn.MaxPool2d(2)
self.bn1 = nn.BatchNorm2d(16)
self.bn2 = nn.BatchNorm2d(32)
self.bn3 = nn.BatchNorm2d(64)
# -- prediction params
self.fc1 = nn.Linear(256, 170)
self.fc2 = nn.Linear(170, 50)
self.fc3 = nn.Linear(50, dim_out)
# -- non-linearity
self.relu = nn.ReLU()
self.Beta = 10
self.sopl = nn.Softplus(beta=self.Beta)
def forward(self, x):
y1 = self.pool(self.bn1(self.relu(self.cn1(x))))
y2 = self.pool(self.bn2(self.relu(self.cn2(y1))))
y3 = self.relu(self.bn3(self.cn3(y2)))
y3 = y3.view(y3.size(0), -1)
y5 = self.sopl(self.fc1(y3))
y6 = self.sopl(self.fc2(y5))
return self.fc3(y6)
class Train:
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# -- data
dim = 28
batch_size = 400
my_transforms = transforms.Compose([transforms.Resize((dim, dim)), transforms.ToTensor()])
trainset = torchvision.datasets.Omniglot(root="./data/omniglot_train/", download=False, transform=my_transforms)
validset = torchvision.datasets.Omniglot(root="./data/omniglot_train/", background=False, download=False,
transform=my_transforms)
self.TrainDataset = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
self.ValidDataset = DataLoader(dataset=validset, batch_size=len(validset), shuffle=False)
self.N_train = len(trainset)
self.N_valid = len(validset)
# -- model
self.model = MyModel().to(self.device)
# -- train
self.epochs = 3000
self.loss = nn.CrossEntropyLoss()
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
def train_epoch(self):
self.model.train()
train_loss = 0
for batch_idx, data_batch in enumerate(self.TrainDataset):
# -- predict
predict = self.model(data_batch[0].to(self.device))
# -- loss
loss = self.loss(predict, data_batch[1].to(self.device))
# -- optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
train_loss += loss.item()
return train_loss/(batch_idx+1)
def valid_epoch(self):
with torch.no_grad():
self.model.eval()
for data_batch in self.ValidDataset:
# -- predict
predict = self.model(data_batch[0].to(self.device))
# -- loss
loss = self.loss(predict, data_batch[1].to(self.device))
return loss.item()
def __call__(self):
for epoch in range(self.epochs):
train_loss = self.train_epoch()
valid_loss = self.valid_epoch()
print('Epoch {}: Training loss = {:.5f}, Validation loss = {:.5f}.'.format(epoch, train_loss, valid_loss))
torch.save(self.model.state_dict(), './model_stat.pth')
if __name__ == '__main__':
my_train = Train()
my_train()