Thanks a lot. I was calculating the validation loss after each epoch using the following code, and I was using that validation loss to update the scheduler. Would this be considered data leakage?
def Solver_NN(model, train_loader, dev_loader, optim, criterion, device, scheduler, print_every=10, epoch=51, lr=1e-1):
    """Train ``model`` and evaluate it on the train and dev loaders after every epoch.

    After each epoch the function computes HTER and loss on both loaders via
    ``HTER``, steps ``scheduler`` with the dev loss, and saves the model's
    ``state_dict`` to a file whose name contains the epoch number and the
    floored dev HTER.

    Args:
        model: Network to train; it is moved to ``device`` first.
        train_loader: Iterable of ``(x, y)`` batches used for optimisation.
        dev_loader: Iterable of ``(x, y)`` batches used for evaluation only.
        optim: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Loss callable invoked as ``criterion(y_pred, y)``.
        device: Device the model and the batches are moved to.
        scheduler: Scheduler whose ``step`` is called with the dev loss once
            per epoch (presumably ``ReduceLROnPlateau`` -- confirm with caller).
        print_every: Number of batches between running-loss printouts.
        epoch: Number of epochs to run.
        lr: Unused by this function; kept so existing callers keep working.

    Returns:
        The trained model.
    """
    print("Solver Initiated")
    model = model.to(device)  # sending model to GPU
    print("Model successfully sent to the GPU\n")
    total_step = len(train_loader)

    for e in range(epoch):
        # Evaluation below switches to eval mode, so re-enable training
        # behaviour (dropout, batch-norm batch statistics) every epoch.
        model.train()
        running_loss = 0.0
        for i, (x, y) in enumerate(train_loader):
            optim.zero_grad()
            X = x.to(device=device, dtype=torch.float)
            y = y.to(device=device, dtype=torch.long)

            # forward pass
            y_pred = model(X)
            loss = criterion(y_pred, y)

            # backward pass
            loss.backward()
            optim.step()

            running_loss += loss.item()
            if (i + 1) % print_every == 0:  # print every `print_every` batches
                print("Epoch [{}/{}], Step [{}/{}] Loss: {}".format(e + 1, epoch, i + 1, total_step, running_loss / print_every))
                running_loss = 0.0

        # torch.no_grad() only disables gradient tracking. Without eval mode,
        # batch-norm layers would still update their running statistics from
        # the evaluation batches (including the dev set) and dropout would
        # stay active, so the metrics would be noisy and the dev data would
        # alter the model state.
        model.eval()
        with torch.no_grad():
            train_hter, train_loss = HTER(model=model, loss_criterion=criterion, loader=train_loader)
            dev_hter, dev_loss = HTER(model=model, loss_criterion=criterion, loader=dev_loader)

        print(f"Train loss in epoch {e+1} is {(train_loss)} and Train HTER in epoch {e+1}: {train_hter}")
        print(f"Dev loss in epoch {e+1} is {(dev_loss)} and Dev HTER in epoch {e+1}: {dev_hter}")
        scheduler.step(dev_loss)
        torch.save(model.state_dict(), sys_path + 'Codes/Replay Attack/weights_replay_attack_1FPS/weights1/Resnet_Replay_attack_No_LBP_'+ str(e+1) + "_" + str(np.floor(dev_hter))+'.pkl')
        print("Model saved successfully!\n")

    return model
This is the `HTER` function that is called inside the `with torch.no_grad():` block:
def HTER(model, loss_criterion, loader):
    """Return the Half Total Error Rate (in percent) and the mean batch loss.

    The model is evaluated in eval mode with gradient tracking disabled, and
    its previous train/eval mode is restored before returning, so the result
    does not depend on the caller having set either.

    Batches are moved to the module-level ``device`` (it is not a parameter
    of this function), and the confusion counts come from the module-level
    ``confusion_matrix`` helper, which is expected to return
    ``(tp, tn, fp, fn)``.

    Args:
        model: Network to evaluate.
        loss_criterion: Loss callable invoked as ``loss_criterion(logits, label)``.
        loader: Iterable of ``(batch, label)`` pairs that supports ``len()``.

    Returns:
        A tuple ``(hter_percent, mean_loss)`` where ``mean_loss`` is the sum of
        the per-batch losses divided by the number of batches.
    """
    labels = []
    predictions = []
    loss_total = 0
    total_step = len(loader)

    was_training = model.training
    model.eval()  # use batch-norm running statistics and disable dropout
    try:
        with torch.no_grad():
            for batch, label in loader:
                batch = batch.to(device=device, dtype=torch.float)
                label = label.to(device=device, dtype=torch.long)
                logits = model(batch)
                labels.append(label)
                predictions.append(torch.argmax(logits, dim=1))
                loss_total += loss_criterion(logits, label).item()
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

    Y = torch.cat(labels, dim=0)
    Y_pred = torch.cat(predictions, dim=0)
    tp, tn, fp, fn = confusion_matrix(Y, Y_pred)
    # HTER = 1 - mean(true-positive rate, true-negative rate).
    # NOTE(review): if the loader contains no samples of one class, one of the
    # denominators is zero -- confirm how confusion_matrix's return type
    # behaves in that case.
    hter = 1 - (0.5 * ((tp / (tp + fn)) + (tn / (tn + fp))))
    return (hter * 100, loss_total / total_step)