I am trying to use transfer learning to train an nf-net f4 model for steganography detection as part of an assignment but training is taking too long. Currently, training one epoch takes about 1 hour.
Here is the model definition:
# create model
model = timm.create_model(model_name,pretrained=True)
# freeze previous layers (for transfer learning)
for param in model.parameters():
param.requires_grad = False
# get number of input features for last layer
num_features = model.head.fc.in_features
# modify last layer
# 2 = modified, not modified
last_layer = nn.Linear(num_features,4)
# set last layer
model.head.fc = last_layer
return model
The training loop is:
# get dataloaders
_,dataloaders, datasets_size,_ = load_datasets(model_name,batch_size,num_workers)
# to plot losses
train_losses = []
val_losses = []
# timer
since = time.time()
# best accuracy
best_acc = 0.0
# best model weights
best_model_wts = copy.deepcopy(model.state_dict())
# training loop
for epoch in range(num_epochs):
epoch_start_time = time.time()
with open(log_file_path,"a",newline='') as log_file:
log_file.write('Epoch {}/{} \n'.format(epoch, num_epochs - 1))
log_file.write("---------- \n")
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
# train and validation phase
for phase in ['train','val']:
# set model mode
if phase == 'train':
model.train()
else:
model.eval()
# current loss and correct accumulator (used for metrics)
current_loss = 0.0
current_corrects = 0
current_training_loss = 0.0
current_val_loss = 0.0
print("Iterating through data")
# iterate over data
for inputs, labels in dataloaders[phase]:
# get inputs
inputs = inputs.to(device)
labels = labels.to(device)
# reset grad
optimizer.zero_grad()
# perform prediction
with torch.set_grad_enabled(phase == 'train'):
# prediction
outputs = model(inputs)
# prediction
_, preds = torch.max(outputs,1)
# loss
loss = criterion(outputs,labels)
if phase == 'train':
# back propagation
loss.backward()
optimizer.step()
# compute performance metrics
current_loss += loss.item() * inputs.size(0)
current_corrects += torch.sum(preds == labels)
if phase == 'train':
scheduler.step()
# calculate epoch metrics
epoch_loss = current_loss / datasets_size[phase]
epoch_acc = current_corrects.double() /datasets_size[phase]
# store losses values
if phase == 'train':
current_training_loss = epoch_loss
train_losses.append(epoch_loss)
else:
current_val_loss = epoch_loss
val_losses.append(epoch_loss)
if phase == 'val':
with open(all_epoch_metrics,"a",newline='') as temp_file:
write_data = f"{epoch},{epoch_acc},{current_val_loss},{current_training_loss}\n"
temp_file.write(write_data)
epoch_end_time = time.time()
epoch_time_interval = epoch_end_time - epoch_start_time
# {:.0f}m {:.0f}s".format(time_since // 60, time_since %60)
with open(log_file_path,"a",newline='') as log_file:
log_file.write("{} \n".format(phase))
log_file.write("start time: {:} seconds \n".format(epoch_start_time))
log_file.write("end time: {:} seconds \n".format(epoch_end_time))
log_file.write("interval: {:.0f} minutes {:.0f} seconds \n".format(epoch_time_interval // 60, epoch_time_interval % 60))
log_file.write("{} Loss: {:.4f} Acc: {:.4f} \n".format(phase,epoch_loss,epoch_acc))
print("Start time: {} seconds".format(epoch_start_time))
print("End Time: {} seconds".format(epoch_end_time))
print("interval: {:.0f} min {:.0f} sec".format(epoch_time_interval // 60, epoch_time_interval % 60))
print("{} Loss: {:.4f} Acc: {:.4f}".format(phase,epoch_loss,epoch_acc))
# check if validation performance improved
if phase == 'val' and epoch_acc > best_acc:
# set new best accuracy
best_acc = epoch_acc
# save model weights
best_model_wts = copy.deepcopy(model.state_dict())
# save model values
current_path = f"./best_model_{epoch}.pth"
with open(best_model_metrics_path,"a",newline='') as best_model_metrics_file:
data_to_write = f"{current_path},{best_acc}\n"
best_model_metrics_file.write(data_to_write)
# save model
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'current_loss': current_loss,
'current_corrects': current_corrects,
'loss': criterion
},current_path)
# check if time to save model
if epoch % checkpoint_epoch_num == 0 and phase == 'val':
# save model
current_path = f"./checkpoint_model_{epoch}.pth"
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'current_loss': current_loss,
'current_corrects': current_corrects,
'loss': criterion
},current_path)
print()
time_since = time.time() - since
with open(log_file_path,"a",newline='') as log_file:
log_file.write("Training complete in {:.0f}m {:.0f}s \n".format(time_since // 60, time_since %60))
log_file.write("Best val acc: {:.4f} \n".format(best_acc))
print("Training complete in {:.0f}m {:.0f}s".format(time_since // 60, time_since %60))
print("Best val acc: {:.4f}".format(best_acc))
I am training on Kaggle’s free GPU with the following parameters:
# call driver program above
main(
model_name="dm_nfnet_f6",
batch_size=32, # 32,64,128,256,512,1024
num_workers=2, #constant
learning_rate=1e-3, # 1e-2,1e-3,1e-4,1e-5,1e-6,1e-7
plot_loss=True,
num_epochs=1, # 20 for testing
step_size=5,
checkpoint_epoch_num=1 # due to low number of epochs executed
)
I am training on a train dataset contains 29432 images and the test set contains 9800.
Is there any way to speed up training? Are there any issues with the code that might be slowing it down?