I've heard that model.eval() should be used during inference, and I see it being used on validation data. If I call it before the validation loop, how do I switch it off when I return to training in the next epoch?
Here is the code. Should I call it right before the validation loop, or only once everything is done and I am evaluating on the test data?
from time import time
# Per-epoch metric history (rounded to 2 decimals), kept for plotting
# learning curves after training finishes.
train_loss_plt = []
val_loss_plt = []
train_acc_plt = []
val_acc_plt = []

for ep in range(15):
    start = time()

    # ---------------- training pass ----------------
    # The module's mode is sticky: the validation pass below calls
    # model.eval(), so train mode has to be switched back on at the start of
    # EVERY epoch. Train mode makes layers such as Dropout and BatchNorm use
    # their training-time behaviour.
    model.train()
    train_loss_sum = 0
    train_acc_sum = 0
    for train_img, train_label in train_loader:
        train_img, train_label = train_img.to(device), train_label.to(device)
        train_out = model(train_img)
        train_loss = criterion(train_out, train_label)
        opt.zero_grad()
        train_loss.backward()
        opt.step()
        train_loss_sum += train_loss.item()
        # Fraction of correct predictions in this batch: index of the highest
        # score along dim 1 compared with the label.
        train_preds = torch.max(train_out, dim=1)[1]
        train_acc_sum += (train_preds == train_label).sum().item() / len(train_label)

    # ---------------- validation pass ----------------
    # eval() switches Dropout/BatchNorm to inference behaviour; no_grad() is a
    # separate concern that only disables gradient tracking. Both are needed,
    # and the same pair should be used for the final test-set evaluation.
    model.eval()
    val_loss_sum = 0
    val_acc_sum = 0
    with torch.no_grad():
        for val_image, val_label in val_loader:
            val_image, val_label = val_image.to(device), val_label.to(device)
            val_out = model(val_image)
            val_loss = criterion(val_out, val_label)
            val_loss_sum += val_loss.item()
            val_preds = torch.max(val_out, dim=1)[1]
            val_acc_sum += (val_preds == val_label).sum().item() / len(val_label)

    end = np.round((time() - start) / 60, 2)  # epoch wall time in minutes

    # Per-epoch averages over batches (every batch weighted equally,
    # regardless of its size).
    val_loss_mean = val_loss_sum / len(val_loader)
    train_avg_loss = np.round(train_loss_sum / len(train_loader), 2)
    val_avg_loss = np.round(val_loss_mean, 2)
    train_avg_acc = np.round(train_acc_sum / len(train_loader), 2)
    val_avg_acc = np.round(val_acc_sum / len(val_loader), 2)

    # Give the scheduler the UNROUNDED validation loss. Rounding to 2 decimals
    # would hide small improvements from a metric-driven scheduler
    # (presumably ReduceLROnPlateau, since step() receives a metric; confirm).
    scheduler.step(val_loss_mean)

    print('Epoch {}, time {} , train acc {}, train loss {} , val acc is {}, loss is {}, learning rate is {} '.format(
        ep, end, train_avg_acc, train_avg_loss, val_avg_acc, val_avg_loss, opt.param_groups[0]['lr']))

    train_loss_plt.append(train_avg_loss)  # training loss for this epoch
    val_loss_plt.append(val_avg_loss)      # validation loss for this epoch
    train_acc_plt.append(train_avg_acc)    # training accuracy for this epoch
    val_acc_plt.append(val_avg_acc)        # validation accuracy for this epoch