geetu (Geetu) · May 4, 2021, 7:19am · #1
Hi,
I need some help doing cross-validation for my code. I am implementing federated learning for cancer prediction, but I don't know how to implement cross-validation in PyTorch. Here is my code:
```python
federated_train_loader = sy.FederatedDataLoader(
    train_data.federate((hospital_1, hospital_2)),
    batch_size=args.batch_size, shuffle=True)
dataloaders['train'] = federated_train_loader


def train_model_federated(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{} at {}'.format(
            epoch, num_epochs - 1,
            datetime.now(my_timezone).strftime('%I:%M:%S %p (%d %b %Y)')))
        print('-' * 10)

        # Each epoch has a training and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train()   # set model to training mode
            else:
                model.eval()    # set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                if phase == 'valid':
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                else:
                    # training batches are pointer tensors on a remote worker
                    inputs = inputs.to(device).get()
                    labels = labels.to(device).get()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        # scheduler.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model if validation accuracy improved
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
```
You can use `Subset` to create different folds for cross-validation by providing train/validation indices:
```python
import numpy as np
from torch.utils.data import Dataset, DataLoader


class Subset(Dataset):
    """Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): the whole dataset
        indices (sequence): indices in the whole set selected for the subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        if self.indices.shape == ():   # zero-dimensional index array
            return 1
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]


k = 5                                # number of folds (example value)
num_val_samples = len(dataset) // k  # validation samples per fold

for i in range(k):
    print('Processing fold: ', i + 1)
    # initiate a new model here in every fold
    valid_idx = np.arange(len(dataset))[i * num_val_samples:(i + 1) * num_val_samples]
    train_idx = np.concatenate([np.arange(len(dataset))[:i * num_val_samples],
                                np.arange(len(dataset))[(i + 1) * num_val_samples:]],
                               axis=0)
    train_dataset = Subset(dataset, train_idx)
    valid_dataset = Subset(dataset, valid_idx)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=1)
```
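If you prefer not to slice the fold indices by hand, scikit-learn's `KFold` generates equivalent train/validation index pairs. A minimal sketch, assuming scikit-learn is available and reusing the `Subset` class above:

```python
from sklearn.model_selection import KFold
import numpy as np
from torch.utils.data import DataLoader

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, valid_idx) in enumerate(kf.split(np.arange(len(dataset)))):
    print('Processing fold:', fold + 1)
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=1, shuffle=True)
    valid_loader = DataLoader(Subset(dataset, valid_idx), batch_size=1, shuffle=False)
    # re-initialize the model here, then train on train_loader
    # and evaluate on valid_loader as usual
```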
Though I don't know your data structure, the code looks OK to me.
geetu (Geetu) · May 7, 2021, 4:31pm · #3
Hi, thank you for the reply, but the code raises an error:

```
AttributeError: 'Subset' object has no attribute 'federate'
```

Here is my code:
```python
def train_model_kfold(model, criterion, optimizer, scheduler, num_epochs=5):
    model_wts = copy.deepcopy(model.state_dict())
    # total_set = datasets.ImageFolder(data_dir)
    splits = KFold(n_splits=5, shuffle=True, random_state=42)
    for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):
        print('Fold : {}'.format(fold))
        dataset_train = Subset(total_set, train_idx)
        dataset_valid = Subset(total_set, valid_idx)
        federated_train_loader = sy.FederatedDataLoader(
            dataset_train.federate((hospital_1, hospital_2)), batch_size=32)
        trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True)
        valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=True)

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        # train_loader = torch.utils.data.DataLoader(
        #     total_set, batch_size=32, sampler=train_sampler)
        # NB: this second loader setup overwrites the loaders built from
        # dataset_train / dataset_valid above
        federated_train_loader = sy.FederatedDataLoader(
            total_set.federate((hospital_1, hospital_2)),
            batch_size=32, sampler=train_sampler)
        valid_loader = torch.utils.data.DataLoader(
            total_set, batch_size=32, sampler=valid_sampler)

        model.load_state_dict(model_wts)
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_corrects = 0
            trunning_corrects = 0
            for inputs, labels in federated_train_loader:
                inputs = inputs.to(device).get()
                labels = labels.to(device).get()
                optimizer.zero_grad()
                with torch.set_grad_enabled(True):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum()
                trunning_corrects += preds.size(0)
            # scheduler.step()
            epoch_loss = running_loss / trunning_corrects
            epoch_acc = (running_corrects.double() * 100) / trunning_corrects
            print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(
                epoch, epoch_loss, epoch_acc))

            model.eval()
            vrunning_loss = 0.0
            vrunning_corrects = 0
            num_samples = 0
            for data, labels in valid_loader:
                data = data.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.no_grad():
                    outputs = model(data)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                vrunning_loss += loss.item() * data.size(0)
                vrunning_corrects += (preds == labels).sum()
                num_samples += preds.size(0)
            vepoch_loss = vrunning_loss / num_samples
            vepoch_acc = (vrunning_corrects.double() * 100) / num_samples
            print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(
                epoch, vepoch_loss, vepoch_acc))
```
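The `AttributeError` above means `.federate()` is not available on these `Subset` objects; PySyft attaches that method to the dataset classes it knows about, and custom wrapper datasets do not always pick it up. One possible workaround, sketched here under the assumption of PySyft 0.2.x (where `sy.BaseDataset` exposes `.federate()`) and that each fold fits in memory, is to stack the subset's samples into plain tensors first. The `federate_subset` helper is introduced here for illustration:

```python
import torch
import syft as sy


def federate_subset(subset, workers):
    # Materialize a Subset into tensors and federate them across workers.
    # Assumes each sample is an (input_tensor, label) pair and that the
    # whole fold fits in memory; large image folds may make this costly.
    data = torch.stack([subset[i][0] for i in range(len(subset))])
    targets = torch.tensor([subset[i][1] for i in range(len(subset))])
    return sy.BaseDataset(data, targets).federate(workers)


# inside the fold loop, instead of dataset_train.federate(...):
federated_train_loader = sy.FederatedDataLoader(
    federate_subset(dataset_train, (hospital_1, hospital_2)),
    batch_size=32, shuffle=True)
```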
Hi, could you reformat the code in markdown? It is hard to read.
geetu (Geetu) · May 8, 2021, 1:43am · #5
Hi,
My apologies. Here is the code:
```python
def train_model_kfold(model, criterion, optimizer, scheduler, num_epochs=5):
    model_wts = copy.deepcopy(model.state_dict())
    # total_set = datasets.ImageFolder(data_dir)
    splits = KFold(n_splits=5, shuffle=True, random_state=123)
    for fold, (train_idx, valid_idx) in enumerate(splits.split(total_set)):
        print('Fold : {}'.format(fold))
        dataset_train = Subset(total_set, train_idx)
        dataset_valid = Subset(total_set, valid_idx)
        federated_train_loader = sy.FederatedDataLoader(
            dataset_train.federate((hospital_1, hospital_2)), batch_size=32)
        valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=32, shuffle=True)

        model.load_state_dict(model_wts)
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            running_corrects = 0
            trunning_corrects = 0
            for inputs, labels in federated_train_loader:
                inputs = inputs.to(device).get()
                labels = labels.to(device).get()
                optimizer.zero_grad()
                with torch.set_grad_enabled(True):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum()
                trunning_corrects += preds.size(0)
            # scheduler.step()
            epoch_loss = running_loss / trunning_corrects
            epoch_acc = (running_corrects.double() * 100) / trunning_corrects
            print('\t\t Training: Epoch({}) - Loss: {:.4f}, Acc: {:.4f}'.format(
                epoch, epoch_loss, epoch_acc))

            model.eval()
            vrunning_loss = 0.0
            vrunning_corrects = 0
            num_samples = 0
            for data, labels in valid_loader:
                data = data.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.no_grad():
                    outputs = model(data)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                vrunning_loss += loss.item() * data.size(0)
                vrunning_corrects += (preds == labels).sum()
                num_samples += preds.size(0)
            vepoch_loss = vrunning_loss / num_samples
            vepoch_acc = (vrunning_corrects.double() * 100) / num_samples
            print('\t\t Validation({}) - Loss: {:.4f}, Acc: {:.4f}'.format(
                epoch, vepoch_loss, vepoch_acc))
```
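One step still missing from the loop above is aggregating the per-fold results, which is what turns five separate runs into a cross-validation estimate. A minimal sketch of that addition (`fold_accs` is a name introduced here, not part of the original code):

```python
fold_accs = []  # final validation accuracy of each fold

# inside train_model_kfold, after each fold's epoch loop finishes:
#     fold_accs.append(vepoch_acc.item())

# after all folds have run, report the cross-validated estimate:
if fold_accs:
    mean_acc = sum(fold_accs) / len(fold_accs)
    print('{}-fold CV accuracy: {:.4f}'.format(len(fold_accs), mean_acc))
```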
Hi @geetu, were you able to get cross-validation working in PyTorch?