Thanks very much for your reply. Yes, I have used Dataset. Here is the implementation:
import torch
import os
import numpy as np
from torch.utils.data import Dataset
#training dataset
#training dataset
class CNNDataset(Dataset):
    """Training dataset pairing an image array with normalized target parameters.

    For index ``idx`` this loads
    ``<root_dir>/LCT9526_nm/<prefix><idx>.npy``   (the image) and
    ``<root_dir>/LCTo9526_nm/<prefix><idx>-y.npy`` (the target vector),
    then divides the first four target entries by 2, 3, 4 and 5
    respectively (dataset-specific normalization).
    """

    def __init__(self, length, prefix, root_dir, transform=None):
        """
        Args:
            length: number of samples (``__len__`` returns ``int(length)``).
            prefix: filename prefix shared by all sample files.
            root_dir: directory containing the two ``.npy`` subfolders.
            transform: optional callable applied to ``[image, para]``.
                NOTE: the original default was ``True``, but ``transform``
                is *called* in ``__getitem__``, so a truthy non-callable
                default would raise TypeError.  ``None`` disables the
                transform safely; callers passing a callable are unaffected.
        """
        self.length = length
        self.prefix = prefix
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return int(self.length)

    def __getitem__(self, idx):
        # DataLoader samplers may hand us a tensor index.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, 'LCT9526_nm',
                                self.prefix + str(idx) + '.npy')
        para_name = os.path.join(self.root_dir, 'LCTo9526_nm/',
                                 self.prefix + str(idx) + '-y' + '.npy')
        image = np.load(img_name)
        para = np.load(para_name)
        # Per-component normalization; divisors are fixed by the data format.
        para[0] = para[0] / 2
        para[1] = para[1] / 3
        para[2] = para[2] / 4
        para[3] = para[3] / 5
        sample = [image, para]
        if self.transform:
            sample = self.transform(sample)
        return sample
#validation dataset
#validation dataset
class CNNDatasetv(Dataset):
    """Validation dataset; identical to ``CNNDataset`` except indices are
    offset by ``base`` so validation files follow the training files.

    Sample ``idx`` maps to files named ``<prefix><idx + base>.npy`` /
    ``<prefix><idx + base>-y.npy``; the first four target entries are
    divided by 2, 3, 4 and 5 respectively.
    """

    def __init__(self, length, base, prefix, root_dir, transform=None):
        """
        Args:
            length: number of validation samples.
            base: index offset added to ``idx`` when building filenames.
            prefix: filename prefix shared by all sample files.
            root_dir: directory containing the two ``.npy`` subfolders.
            transform: optional callable applied to ``[image, para]``.
                NOTE: the original default was ``True``, but ``transform``
                is *called* in ``__getitem__``, so a truthy non-callable
                default would raise TypeError; ``None`` is the safe default.
        """
        self.length = length
        self.base = base
        self.prefix = prefix
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return int(self.length)

    def __getitem__(self, idx):
        # DataLoader samplers may hand us a tensor index.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, 'LCT9526_nm',
                                self.prefix + str(idx + self.base) + '.npy')
        para_name = os.path.join(self.root_dir, 'LCTo9526_nm/',
                                 self.prefix + str(idx + self.base) + '-y' + '.npy')
        image = np.load(img_name)
        para = np.load(para_name)
        # Per-component normalization; divisors are fixed by the data format.
        para[0] = para[0] / 2
        para[1] = para[1] / 3
        para[2] = para[2] / 4
        para[3] = para[3] / 5
        sample = [image, para]
        if self.transform:
            sample = self.transform(sample)
        return sample
(Given the structure of my data, I have to define a separate Dataset class for training and for validation.)
Here is how I generate data:
def create_datasets(batch_size):
    """Build the train/validation/test DataLoaders over the ``.npy`` datasets.

    Args:
        batch_size: batch size for the training loader only; the
            validation and test loaders use batch size 1.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    # One shared ToTensor pipeline and data root for all three splits.
    to_tensor = transforms.Compose([ToTensor()])
    data_root = '/scratch/zxs/'

    trainset = CNNDataset(length=lenghtr, prefix='Idlt-',
                          root_dir=data_root, transform=to_tensor)
    validateset = CNNDatasetv(length=lenghva, base=base1, prefix='Idlt-',
                              root_dir=data_root, transform=to_tensor)
    testset = CNNDatasett(length=lenghte, base=base2, prefix='Idlt-',
                          root_dir=data_root, transform=to_tensor)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=0, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        validateset, batch_size=1, shuffle=True,
        num_workers=0, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=0)

    return train_loader, valid_loader, test_loader
And below is the code for my training loop:
def train_model(patience, n_epochs, save_early_path, net, train_loader, valid_loader):
    """Train ``net`` with early stopping; return per-epoch loss histories.

    Relies on module-level ``optimizer``, ``criterion``, ``scheduler`` and
    ``device`` (defined elsewhere in this file — presumably a
    ReduceLROnPlateau-style scheduler, since ``scheduler.step`` receives the
    validation loss; TODO confirm).

    Args:
        patience: epochs without validation improvement before stopping.
        n_epochs: maximum number of epochs.
        save_early_path: checkpoint path used by ``EarlyStopping``.
        net: the model to train (already on ``device``).
        train_loader: DataLoader yielding ``(inputs, targets)`` batches.
        valid_loader: DataLoader for validation batches.

    Returns:
        ``(avg_train_losses, avg_valid_losses)`` — one average per epoch.
    """
    # Per-batch losses for the current epoch (cleared each epoch).
    train_losses = []
    valid_losses = []
    # Per-epoch average losses (returned to the caller).
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):
        ######################
        # train the model    #
        ######################
        net.train()
        for i, data in enumerate(train_loader, 0):
            inputs, para = (data[0].to(device, non_blocking=True),
                            data[1].to(device, non_blocking=True))
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, para)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # BUGFIX: empty_cache() was called after *every batch*, which forces a
        # device synchronization and badly slows training.  Once per epoch is
        # already more than necessary; PyTorch's caching allocator reuses
        # memory without it.
        torch.cuda.empty_cache()

        ######################
        # validate the model #
        ######################
        net.eval()
        with torch.no_grad():
            for val in valid_loader:
                inputsv, parav = (val[0].to(device, non_blocking=True),
                                  val[1].to(device, non_blocking=True))
                outputsv = net(inputsv)
                loss = criterion(outputsv, parav)
                valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # Scheduler watches the validation loss (ReduceLROnPlateau-style).
        scheduler.step(valid_loss)
        # (Removed leftover debug print of sys.getsizeof(scheduler):
        # getsizeof is shallow and says nothing useful about the scheduler.)

        epoch_len = len(str(n_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        state = {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'valid_loss': valid_loss
        }
        early_stopping(valid_loss, state, save_early_path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    checkpoint = torch.load(save_early_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return avg_train_losses, avg_valid_losses