# Collect the train/validation index arrays produced for each fold so the
# folds can be replayed later (e.g. once per epoch).
train_indexes = []
valid_indexes = []

# skf is presumably a scikit-learn (Stratified)KFold splitter -- confirm.
# Each iteration unpacks one (train_index, valid_index) pair.
for train_index, valid_index in skf.split(
    train_loader.dataset.imgs, train_loader.dataset.targets
):
    train_indexes.append(train_index)
    valid_indexes.append(valid_index)
    print(train_index, valid_index)

# NOTE: in Python 3, zip() returns a one-shot iterator; it can be consumed
# only once. Re-zip the two lists when the folds must be iterated again.
splits = zip(train_indexes, valid_indexes)
[ 451 452 453 … 21514 21515 21516] [ 0 1 2 … 18402 18403 18404]
[ 0 1 2 … 21514 21515 21516] [ 451 452 453 … 18847 18848 18849]
[ 0 1 2 … 21514 21515 21516] [ 902 903 904 … 19292 19293 19294]
[ 0 1 2 … 21514 21515 21516] [ 1353 1354 1355 … 19737 19738 19739]
[ 0 1 2 … 21514 21515 21516] [ 1803 1804 1805 … 20182 20183 20184]
[ 0 1 2 … 21514 21515 21516] [ 2253 2254 2255 … 20626 20627 20628]
[ 0 1 2 … 21514 21515 21516] [ 2703 2704 2705 … 21070 21071 21072]
[ 0 1 2 … 21070 21071 21072] [ 3153 3154 3155 … 21514 21515 21516]
from torch.utils.data import DataLoader

EPOCHS = 5
SAVE_DIR = 'models'
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'please.pt')

# Lowest validation loss seen so far, across all epochs and folds.
best_valid_loss = float('inf')

# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(SAVE_DIR, exist_ok=True)

print("start")
for epoch in range(EPOCHS):
    print('================================', epoch, '================================')
    # NOTE(review): the same `model` is trained on every fold without being
    # re-initialised, so later folds validate on samples the model has
    # already been trained on -- confirm this is intended.
    for i, (train_idx, valid_idx) in enumerate(zip(train_indexes, valid_indexes)):
        print(i, train_idx, valid_idx, len(train_idx), len(valid_idx))

        # Both loaders wrap the same full dataset; the sampler alone decides
        # which indices each loader draws.
        trainloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(train_idx),
        )
        valloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(valid_idx),
        )

        # A sampler does not shrink `loader.dataset`, so len(loader.dataset)
        # is always the full dataset size. The number of samples actually
        # drawn per pass is the sampler's length.
        print(len(trainloader.sampler), len(valloader.sampler))

        train_loss, train_acc, model = train(model, device, trainloader, optimizer, criterion)
        valid_loss, valid_acc, model = evaluate(model, device, valloader, criterion)

        # Keep only the checkpoint with the best validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model, MODEL_SAVE_PATH)

        print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |')
I found something strange:
`print(i, train_idx, valid_idx, len(train_idx), len(valid_idx))` reports lengths of 18827 and 2690,
but `print(len(trainloader.dataset), len(valloader.dataset))` reports 21517 and 21517.
Is my SubsetRandomSampler not working?
the output of