I need help with this K-fold cross-validation implementation

I have implemented a feed-forward neural network in PyTorch to classify an image dataset using K-fold cross-validation, and I have a problem during training: in every fold, the validation accuracy and loss are better than the training accuracy and loss. I tried a different dataset and the result is the same. I am fine-tuning VGG16. Any tips on how this could happen?

total_set  = datasets.ImageFolder(root_dir)
splits = KFold(n_splits = 5, shuffle = True, random_state = 42)
.............
for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):
    print('Fold : {}'.format(fold))
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
                      WrapperDataset(total_set,  transform=transforms['train']), 
                      batch_size=64, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
                      WrapperDataset(total_set, transform = transforms['valid']),
                      batch_size=64, sampler=valid_sampler)
    model.load_state_dict(model_wts)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        num_train_samples = 0  # number of training images seen this epoch
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum()
            num_train_samples += preds.size(0)
        # step the LR scheduler once per epoch, not once per batch: epoch-based
        # schedulers such as StepLR count calls to scheduler.step()
        scheduler.step()

        epoch_loss = running_loss / num_train_samples
        epoch_acc = (running_corrects.double()*100) / num_train_samples
        print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

        model.eval()   
        vrunning_loss = 0.0
        vrunning_corrects = 0
        num_samples = 0
        for data, labels in valid_loader:
            data = data.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                outputs = model(data)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
            vrunning_loss += loss.item() * data.size(0)
            vrunning_corrects += (preds == labels).sum()
            num_samples += preds.size(0)
        vepoch_loss = vrunning_loss/num_samples
        vepoch_acc = (vrunning_corrects.double() * 100)/num_samples
        print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, vepoch_loss, vepoch_acc))
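Where `scheduler.step()` is called matters. `StepLR` counts calls to `scheduler.step()`, not epochs, so calling it inside the batch loop with the `StepLR(step_size=10, gamma=0.1)` scheduler used in this thread divides the learning rate by 10 every 10 batches. With `batch_size=64`, a fold with 20 or more batches per epoch ends its first epoch with the learning rate divided by 100, which could make the metrics look flat afterwards. A minimal sketch, using a single dummy parameter in place of the real model (an assumption for illustration only) and a hypothetical helper `lr_after_steps`:

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

def lr_after_steps(num_steps, step_size=10, gamma=0.1, base_lr=0.001):
    """Hypothetical helper: the LR after calling scheduler.step() num_steps times."""
    # a single dummy parameter stands in for the real model's parameters
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = optim.SGD([param], lr=base_lr, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    for _ in range(num_steps):
        optimizer.step()   # no gradients here, so the parameter is left unchanged
        scheduler.step()   # multiplies the LR by gamma every step_size calls
    return scheduler.get_last_lr()[0]

# per-batch stepping: 21 batches in a single epoch already means two decays
print(lr_after_steps(21))  # about 1e-05
# per-epoch stepping: 5 epochs stay below step_size, so the LR is unchanged
print(lr_after_steps(5))   # 0.001
```

Stepping once per epoch, as in the `num_epochs=5` setup shown later, means the decay at `step_size=10` is never reached within a fold.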

The wrapper class applies the required transformations to the ImageFolder dataset:

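class WrapperDataset(torch.utils.data.Dataset):
    """Applies split-specific transforms on top of a shared ImageFolder dataset."""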
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            #image = transforms.ToPILImage()(image)
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)
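One alternative to pairing `WrapperDataset` with `SubsetRandomSampler` is to wrap each transformed view in `torch.utils.data.Subset` for the current fold and let the `DataLoader` shuffle it. The sketch below is an assumption-laden illustration: `make_fold_loaders` is a hypothetical helper, the class is a trimmed copy of the wrapper above (without `target_transform`), and a toy list of `(tensor, label)` pairs with simple functions stands in for `ImageFolder` and the torchvision transforms.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset

class WrapperDataset(Dataset):
    """Trimmed copy of the wrapper above: applies one transform to a shared base dataset."""
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.dataset)

def make_fold_loaders(base, train_idx, valid_idx, train_tf, valid_tf, batch_size=64):
    """Hypothetical helper: one fold's loaders, built with Subset instead of samplers."""
    train_set = Subset(WrapperDataset(base, transform=train_tf), train_idx)
    valid_set = Subset(WrapperDataset(base, transform=valid_tf), valid_idx)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader

# toy stand-in for ImageFolder: 10 (tensor, label) pairs
base = [(torch.tensor([float(i)]), i % 2) for i in range(10)]
train_loader, valid_loader = make_fold_loaders(
    base, train_idx=list(range(8)), valid_idx=[8, 9],
    train_tf=lambda x: x + 100.0,  # pretend "augmentation" so the two views differ
    valid_tf=None, batch_size=4)

# len(loader.dataset) is now the fold size, not the size of the full dataset
print(len(train_loader.dataset), len(valid_loader.dataset))  # 8 2
```

With `SubsetRandomSampler`, `len(loader.dataset)` still reports the full dataset, which is why the loop above counts samples itself; with `Subset`, both give the fold size.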

When I run this cross-validation on the pre-trained VGG16, I get the following output, which is almost the same for all other folds. I don’t understand why the network is not stable, or why the model performs better on the validation set.

---> Fold : 0
Epoch : 0 Training:  Loss: 0.3931, Acc: 85.0000 : Validation(0) - Loss: 0.3516, Acc: 93.6556
Epoch : 1 Training:  Loss: 0.4071, Acc: 85.7576 : Validation(1) - Loss: 0.3743, Acc: 90.0302
Epoch : 2 Training:  Loss: 0.3980, Acc: 87.7273 : Validation(2) - Loss: 0.3736, Acc: 90.0302
Epoch : 3 Training:  Loss: 0.4017, Acc: 84.6970 : Validation(3) - Loss: 0.3427, Acc: 92.4471
.............
Epoch : 6 Training:  Loss: 0.4086, Acc: 85.6061 : Validation(6) - Loss: 0.3704, Acc: 90.3323
.............

Sorry for the long post; I have spent a week trying to solve this issue. Any suggestions?


How did you create model_wts? Note that state_dict() returns references to the parameters, so the stored tensors would also get updated during training if you didn’t use copy.deepcopy(model.state_dict()).
In that case, reloading model_wts would just load the current weights, and each fold would simply continue training the same model.
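The reference behavior described above can be checked on a tiny `nn.Linear` standing in for VGG16 (an illustrative assumption); an in-place `add_` plays the role of an optimizer update:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(2, 1)

shallow = model.state_dict()              # tensors share storage with the parameters
deep = copy.deepcopy(model.state_dict())  # independent snapshot of the initial weights

with torch.no_grad():
    model.weight.add_(1.0)                # stand-in for an in-place optimizer update

print(torch.equal(shallow['weight'], model.weight))  # True: the "saved" weights moved too
print(torch.equal(deep['weight'], model.weight))     # False: the snapshot kept the old values

model.load_state_dict(deep)               # copies the initial weights back into the model
print(torch.equal(deep['weight'], model.weight))     # True
```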


model_wts is a copy of the initial weights of the pre-trained VGG16 model. I use load_state_dict to reload these initial weights at the start of each fold. The structure of the code is as follows:

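import copy

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    # deepcopy, so training does not overwrite the saved initial weights
    model_wts = copy.deepcopy(model.state_dict())
    .....
    for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):
        print('Fold : {}'.format(fold))
        train_sampler = SubsetRandomSampler(train_idx)
        ........
        # this resets the model weights only: the same optimizer (with its
        # momentum buffers) and scheduler (with its decayed LR) are reused
        model.load_state_dict(model_wts)
        for epoch in range(num_epochs):
            # The same code

# fine-tune a pre-trained VGG16 with a new 2-class output layer
model_ft = models.vgg16(pretrained=True)
num_ftrs = model_ft.classifier[6].in_features
model_ft.classifier[6] = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# StepLR multiplies the LR by gamma every step_size calls to scheduler.step()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=5)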
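In this structure the same `optimizer_ft` and `exp_lr_scheduler` objects serve every fold, so only the model weights are reset; SGD's momentum buffers and any learning-rate decay from earlier folds carry over. A minimal sketch of resetting all three at the start of each fold, assuming a hypothetical helper `start_fold`, the SGD/StepLR hyperparameters from above, and a tiny `nn.Linear` with random data standing in for VGG16 and the image folds:

```python
import copy
import torch
from torch import nn, optim
from torch.optim import lr_scheduler

def start_fold(model, initial_wts, lr=0.001, momentum=0.9, step_size=10, gamma=0.1):
    """Hypothetical helper: reset weights and build a fresh optimizer/scheduler for one fold."""
    model.load_state_dict(initial_wts)  # back to the saved initial weights
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)  # empty momentum buffers
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)  # LR starts at lr
    return optimizer, scheduler

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the fine-tuned VGG16
initial_wts = copy.deepcopy(model.state_dict())
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

final_lrs, final_weights = [], []
for fold in range(2):
    optimizer, scheduler = start_fold(model, initial_wts)
    for epoch in range(12):  # long enough for one StepLR decay
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()     # once per epoch
    final_lrs.append(scheduler.get_last_lr()[0])
    final_weights.append(model.weight.detach().clone())

print(final_lrs)  # the same LR at the end of every fold
```

Because every fold starts from identical weights and optimizer state, training on the same data reproduces the same result, which makes folds directly comparable.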

During training, the validation accuracy is either better than the training accuracy or stays constant.

That might be expected, e.g. if you are using dropout during training (the VGG16 classifier contains two nn.Dropout layers).
Since the model in training mode has reduced capacity, its loss might be higher than the validation loss, which is computed with the complete model in eval mode.
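To see this effect in isolation, the sketch below runs one batch through a small, made-up classifier head containing `nn.Dropout` in `train()` and `eval()` mode; the layer sizes and random data are illustrative assumptions, not taken from the thread. In training mode the loss changes on every call because a new dropout mask is drawn, while in eval mode it is deterministic. One way to compare like with like is to also evaluate the training folds with `model.eval()` under `torch.no_grad()`.

```python
import torch
from torch import nn

torch.manual_seed(0)
# toy head with dropout, loosely shaped like a classifier (hypothetical sizes)
head = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(64, 16), torch.randint(0, 2, (64,))

with torch.no_grad():
    head.train()  # dropout active: each unit is zeroed with probability 0.5
    train_losses = [criterion(head(x), y).item() for _ in range(3)]
    head.eval()   # dropout disabled: the full network is used
    eval_losses = [criterion(head(x), y).item() for _ in range(3)]

print(train_losses)  # differs from call to call
print(eval_losses)   # identical values
```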
