Cross validation in Pytorch

Hi,
I need some help to do cross validation for my code. I am implementing federated learning for cancer prediction. But don’t know to how to implement cross validation in pytorch. Here is my code

federated_train_loader = sy.FederatedDataLoader(train_data.federate((hospital_1, hospital_2)), batch_size=args.batch_size, shuffle=True)
dataloaders['train'] = federated_train_loader
def train_model_federated(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{} at {}'.format(epoch, num_epochs - 1, datetime.now(my_timezone).strftime('%I:%M:%S %p (%d %b %Y)')))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                if phase == 'valid':
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                else:
                    inputs = inputs.to(device).get()
                    labels = labels.to(device).get()
                
                # zero the parameter gradients
                optimizer.zero_grad()
                #print("Enter 2st loop")
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
#                         scheduler.step()


                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

You can use Subset to create different folds for cross-validation by providing data/train/validation indices.

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.
    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        if self.indices.shape == ():
            print('this happens: Subset')
            return 1
        else:
            return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

for i in range(k):
    print('Processing fold: ', i + 1)
    """%%%% Initiate new model %%%%""" #in every fold
    valid_idx = np.arange(len(dataset))[i * num_val_samples:(i + 1) * num_val_samples]
    train_idx = np.concatenate([np.arange(len(dataset))[:i * num_val_samples], np.arange(len(dataset))[(i + 1) * num_val_samples:]], axis=0)
    train_dataset = Subset(dataset, train_idx)
    valid_dataset = Subset(dataset, valid_idx)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=1)

Though I don’t know your data structure, the codes look ok to me.

Hi thank you for the reply but the code is showing error

AttributeError: ‘Subset’ object has no attribute ‘federate’

Here is my code

def train_model_kfold(model, criterion, optimizer, scheduler, num_epochs=5):

model_wts = copy.deepcopy(model.state_dict())

# total_set  = datasets.ImageFolder(data_dir)

splits = KFold(n_splits = 5, shuffle = True, random_state = 42)

for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):

    print('Fold : {}'.format(fold))

    dataset_train = Subset(total_set, train_idx)

    dataset_valid = Subset(total_set, valid_idx)

    federated_train_loader = sy.FederatedDataLoader(dataset_train.federate((hospital_1, hospital_2)), batch_size=32)

    trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle =True)

    valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=True)

train_sampler = SubsetRandomSampler(train_idx)

valid_sampler = SubsetRandomSampler(valid_idx)

# train_loader = torch.utils.data.DataLoader(

# total_set,

# batch_size=32, sampler=train_sampler)

federated_train_loader = sy.FederatedDataLoader(total_set.federate((hospital_1, hospital_2)), batch_size=32,sampler=train_sampler)

valid_loader = torch.utils.data.DataLoader(

total_set,

batch_size=32, sampler=valid_sampler)

    model.load_state_dict(model_wts)

    for epoch in range(num_epochs):  

        model.train()

        running_loss = 0.0

        running_corrects = 0

        trunning_corrects = 0

        for inputs, labels in federated_train_loader:

            inputs = inputs.to(device).get()

            labels = labels.to(device).get()

            optimizer.zero_grad()

            with torch.set_grad_enabled(True):

                outputs = model(inputs)

                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, labels)

                loss.backward()

                optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            running_corrects += (preds == labels).sum()

            trunning_corrects += preds.size(0)

            # scheduler.step()

        epoch_loss = running_loss / trunning_corrects

        epoch_acc = (running_corrects.double()*100) / trunning_corrects

        print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

        model.eval()   

        vrunning_loss = 0.0

        vrunning_corrects = 0

        num_samples = 0

        for data, labels in valid_loader:

            data = data.to(device)

            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.no_grad():

                outputs = model(data)

                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, labels)

            vrunning_loss += loss.item() * data.size(0)

            vrunning_corrects += (preds == labels).sum()

            num_samples += preds.size(0)

        vepoch_loss = vrunning_loss/num_samples

        vepoch_acc = (vrunning_corrects.double() * 100)/num_samples

        print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, vepoch_loss, vepoch_acc))

hi, could you reformat the code in markdown? They are hard to read.

Hi,
My apologies. Here is the code

def train_model_kfold(model, criterion, optimizer, scheduler, num_epochs=5):
model_wts = copy.deepcopy(model.state_dict())

# total_set  = datasets.ImageFolder(data_dir)

splits = KFold(n_splits = 5, shuffle = True, random_state = 123)

for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):

    print('Fold : {}'.format(fold))

    dataset_train = Subset(total_set, train_idx)

    dataset_valid = Subset(total_set, valid_idx)

    federated_train_loader = sy.FederatedDataLoader(dataset_train.federate((hospital_1, hospital_2)), batch_size=32)

    valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=True)

   

    model.load_state_dict(model_wts)

    for epoch in range(num_epochs):  

        model.train()

        running_loss = 0.0

        running_corrects = 0

        trunning_corrects = 0

        for inputs, labels in federated_train_loader:

            inputs = inputs.to(device).get()

            labels = labels.to(device).get()

            optimizer.zero_grad()

            with torch.set_grad_enabled(True):

                outputs = model(inputs)

                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, labels)

                loss.backward()

                optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            running_corrects += (preds == labels).sum()

            trunning_corrects += preds.size(0)

            # scheduler.step()

        epoch_loss = running_loss / trunning_corrects

        epoch_acc = (running_corrects.double()*100) / trunning_corrects

        print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

        model.eval()   

        vrunning_loss = 0.0

        vrunning_corrects = 0

        num_samples = 0

        for data, labels in valid_loader:

            data = data.to(device)

            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.no_grad():

                outputs = model(data)

                _, preds = torch.max(outputs, 1)

                loss = criterion(outputs, labels)

            vrunning_loss += loss.item() * data.size(0)

            vrunning_corrects += (preds == labels).sum()

            num_samples += preds.size(0)

        vepoch_loss = vrunning_loss/num_samples

        vepoch_acc = (vrunning_corrects.double() * 100)/num_samples

        print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(epoch, vepoch_loss, vepoch_acc))

Hi @geetu were you able to use cross validation in pytorch?