Hello everybody,
Over the last few days I have run into a very strange problem: my training stops at the end of the training phase of the first epoch (it never performs the validation step), without raising any error.
The main part of my training code is shown below. I have tested it multiple times, and it works well on a subset of my dataset. However, whenever I run it on the full dataset (over 700,000 images, split 9:1 into train and val), it stops after the first training phase and never enters the validation phase.
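For reference, a 9:1 train/val split amounts to shuffling the indices once and cutting at 90%. The plain-Python sketch below only illustrates that arithmetic; the helper name `split_9_to_1`, the fixed seed and the floor rounding are my own assumptions, not necessarily how my split was made. The resulting index lists could be passed to `torch.utils.data.Subset`, and `torch.utils.data.random_split` combines both steps.

```python
import random


def split_9_to_1(n, seed=0):
    """Hypothetical helper: shuffle the indices 0..n-1 and cut them 9:1 into train/val."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed -> the same split on every run
    n_train = n * 9 // 10                 # integer arithmetic; any remainder goes to val
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_9_to_1(1000)
print(len(train_idx), len(val_idx))  # 900 100
```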
I remember having a similar issue some time ago: training did not enter the validation phase and crashed with a segmentation fault. This time, however, there are no errors at all (the last lines of the terminal output are shown after the code).
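Because nothing is raised, one standard-library tool that might show where the process is stuck is `faulthandler.dump_traceback_later`, which dumps the stack of every thread after a timeout. Below is a minimal sketch that wraps it in a context manager; the name `hang_watchdog` and the timeout values are illustrative assumptions. It only covers the threads of the process it runs in.

```python
import contextlib
import faulthandler
import sys


@contextlib.contextmanager
def hang_watchdog(timeout_s, file=sys.stderr):
    """Hypothetical helper: while the block runs, dump the stack of every
    thread in this process to `file` every `timeout_s` seconds."""
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)
    try:
        yield
    finally:
        # Disarm the timer once the block exits (normally or via an exception).
        faulthandler.cancel_dump_traceback_later()
```

For example, `with hang_watchdog(1800):` around the `for phase in ['train', 'val']:` loop would print a stack dump every 30 minutes, so the dumps written after the last progress line would show which call the main process is blocked in. The 30-minute interval is an arbitrary choice.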
Could you please help me find out what is happening?
Thank you so much in advance for any suggestions!
    for phase in ['train', 'val']:
        print('Entering in phase:', phase)
        if phase == 'train':
            scheduler.step(best_acc)
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        print('Iterating over data:')
        n_samples = 0
        for batch_idx, (inputs, labels) in enumerate(dataloaders[phase]):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            n_samples += len(labels)
            if phase == 'train':
                print('{}/{} Avg Loss: {:.4f} Avg Acc: {:.4f}'.format(n_samples, dataset_sizes[phase],
                      running_loss/n_samples, running_corrects.double()/n_samples))

        print('Done iterating over data')
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]
        print('\t{} Loss: {:.4f}\t{} Acc: {:.4f}'.format(phase, epoch_loss, phase, epoch_acc))

    # Here we are at the end of the 'val' phase
    # check for improvement over the last epochs
    if epoch_acc > best_acc:
        print('Improved.')
        best_acc = epoch_acc
        try:
            state_dict = model.module.state_dict()
        except AttributeError:
            state_dict = model.state_dict()
        torch.save(state_dict, save_path)
Last lines of the terminal output (note that "Done iterating over data" is never printed):

    632736/633656 Avg Loss: 0.0436 Avg Acc: 0.9868
    632944/633656 Avg Loss: 0.0436 Avg Acc: 0.9868
    633152/633656 Avg Loss: 0.0435 Avg Acc: 0.9868
    633360/633656 Avg Loss: 0.0435 Avg Acc: 0.9868
    633568/633656 Avg Loss: 0.0435 Avg Acc: 0.9868
    633656/633656 Avg Loss: 0.0435 Avg Acc: 0.9868
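Reading these progress lines: the counter advances by 208 images per step and the final step adds only 88 (633,656 = 3,046 × 208 + 88), so the last training batch is a partial one. Because the loop weights each batch loss by `inputs.size(0)`, the printed average stays a true per-sample mean even with that short batch. The plain-Python sketch below checks that arithmetic; the helper name `per_sample_mean` is hypothetical and the loss values are made up, not taken from my run.

```python
def per_sample_mean(batch_losses, batch_sizes):
    """Weighted mean of per-batch mean losses, mirroring
    `running_loss += loss.item() * inputs.size(0)` followed by `running_loss / n_samples`."""
    running_loss = 0.0
    n_samples = 0
    for loss, size in zip(batch_losses, batch_sizes):
        running_loss += loss * size  # turn the batch mean back into a batch sum
        n_samples += size
    return running_loss / n_samples


# Illustrative numbers only: three full batches of 208 and a final partial batch of 88.
sizes = [208, 208, 208, 88]
losses = [0.04, 0.04, 0.04, 0.10]
print(round(per_sample_mean(losses, sizes), 6))  # 0.047416
print(round(sum(losses) / len(losses), 6))       # 0.055, naive mean of batch means
```

The naive mean of batch means would over-weight the short final batch, which is why the per-sample weighting is used.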