Model train, val, test workflow verification in PyTorch

I was not sure where the best place would be to get a code review of a seemingly working piece of PyTorch code. Could you please let me know if I am doing something wrong? I was able to fix my previous problem of test set accuracy being stuck at 0 or 1, and I now get around 70% accuracy on my test set. I would just like an expert opinion on whether the order of operations in the code below is OK.

if train:
    torch.autograd.set_detect_anomaly(True)
    for i_batch, sample_batched in enumerate(dataloader_train):
        feats, labels, _ = prepare_features(sample_batched['image'], sample_batched['label'])

        # loss_type='CrossEntropyLoss', phase='train')
        output = model(feats)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == labels).float().mean()
        epoch_accuracy += acc / len(dataloader_train)
        epoch_loss += loss / len(dataloader_train)

if not train:
    with torch.no_grad():

        epoch_val_accuracy = 0
        epoch_val_loss = 0
        epoch_val_preds = []
        epoch_val_labels = []

        #model.eval()
        print("evaluating...")
        total = 0.
        batch_idx = 0
        val_preds = []
        val_labels = []
        predictions = []
        actuals = []
        for i_batch, sample_batched in enumerate(dataloader_val):
            feats, labels, _ = prepare_features(sample_batched['image'], sample_batched['label'])
            val_output = model(feats)
            val_loss = criterion(val_output, labels)
            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_accuracy += acc / len(dataloader_val)
            epoch_val_loss += val_loss / len(dataloader_val)
            epoch_val_preds.extend(val_output.argmax(dim=1).cpu().numpy())
            epoch_val_labels.extend(labels.cpu().numpy())

    if not test:
        if epoch_val_accuracy > best_val_acc:
            print('inside if - epoch is {}, val_acc is {}, and best_pred is {}'.format(epoch, epoch_val_accuracy, best_val_acc))
            best_val_acc = epoch_val_accuracy
            best_epoch = epoch
            best_preds = epoch_val_preds
            best_val_labels = epoch_val_labels
            print("Saving the best model...")
            torch.save(model.state_dict(), model_path + task_name + ".pth")

Please note that I use the same code for the validation and test phases, but for testing I run a different bash script that sets the "test" flag and runs for only one epoch. If the flag is not set to "test", I save the .pth file that corresponds to the best validation accuracy. My model is a vanilla vision transformer that returns two output values per sample instead of one, i.e. a 2D tensor of shape D×2, where D is the batch size. I am using nn.CrossEntropyLoss for binary classification.
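To make the shapes concrete, here is a minimal sketch of the setup described above. It assumes the logits have shape (D, 2) and the labels are integer class indices of shape (D,); the tensors are random placeholders, not the actual model output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = 8                                   # batch size (placeholder value)
logits = torch.randn(D, 2)              # stand-in for the model output, one row per sample
labels = torch.randint(0, 2, (D,))      # class indices 0 or 1, dtype int64

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)        # scalar tensor, averaged over the batch
preds = logits.argmax(dim=1)            # shape (D,), predicted class per sample
acc = (preds == labels).float().mean()  # fraction of correct predictions in the batch

print(logits.shape, preds.shape, loss.item(), acc.item())
```

With this layout, argmax(dim=1) picks the class per sample, which is what the accuracy lines in the code above rely on.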

Most of my code above follows code block #25 of this notebook: https://github.com/lucidrains/vit-pytorch/blob/main/examples/cats_and_dogs.ipynb

Sorry, I was not able to fully fix the indentation on the PyTorch forum.

I think the code is correct (except for some indentation problems from the PyTorch forum :-)).

However, all the conditions in the code can be avoided by writing separate functions for a training step and for validation/testing.

For example:

def train_step(model, criterion, optimizer, dataloader_train):
    ...
    for i_batch, sample_batched in enumerate(dataloader_train):
        ...
        optimizer.step()
        ...
    return loss, acc

def val_step(model, criterion, dataloader_val):
    ...
    with torch.no_grad():
        ...
        for i_batch, sample_batched in enumerate(dataloader_val):
            ...
        ...
    return loss, acc

# Training loop
for epoch in range(N_epochs):
    ...
    train_loss, train_acc = train_step(model, criterion, optimizer, dataloader_train)
    val_loss, val_acc = val_step(model, criterion, dataloader_val)
    ...
    if val_acc > best_val_acc:
        ...
    ...

# Testing
test_loss, test_acc = val_step(model, criterion, dataloader_test)
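
A filled-in version of this outline might look like the sketch below. It is only an illustration under several assumptions: the tiny MLP and random data stand in for the ViT and the real dataloaders, batches are plain (feats, labels) tuples rather than dicts passed through prepare_features, and the best weights are kept in memory instead of being written with torch.save. It also calls model.train() and model.eval() explicitly, which matters for layers such as dropout, and uses .item() so the running totals are plain Python numbers.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_step(model, criterion, optimizer, dataloader):
    """Run one training epoch; return mean loss and accuracy over all samples."""
    model.train()  # enable training behaviour (e.g. dropout)
    total_loss, correct, n = 0.0, 0, 0
    for feats, labels in dataloader:
        optimizer.zero_grad()
        output = model(feats)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        # .item() yields Python numbers, so no autograd history is accumulated
        total_loss += loss.item() * labels.size(0)
        correct += (output.argmax(dim=1) == labels).sum().item()
        n += labels.size(0)
    return total_loss / n, correct / n


def val_step(model, criterion, dataloader):
    """Evaluate without gradient tracking; reusable for validation and test."""
    model.eval()  # disable dropout etc. during evaluation
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for feats, labels in dataloader:
            output = model(feats)
            total_loss += criterion(output, labels).item() * labels.size(0)
            correct += (output.argmax(dim=1) == labels).sum().item()
            n += labels.size(0)
    return total_loss / n, correct / n


# Synthetic placeholder data: 10 features, binary label from the first feature
torch.manual_seed(0)
X = torch.randn(64, 10)
y = (X[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(X[:48], y[:48]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(X[48:], y[48:]), batch_size=16)

# Placeholder model with two output logits, as in the question
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Dropout(0.1), nn.Linear(16, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

best_val_acc, best_state = -1.0, None
for epoch in range(3):
    train_loss, train_acc = train_step(model, criterion, optimizer, train_loader)
    val_loss, val_acc = val_step(model, criterion, val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Copy the weights; torch.save(model.state_dict(), path) would go here instead
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    print(f"epoch {epoch}: train_acc={train_acc:.3f} val_acc={val_acc:.3f}")
```

Weighting each batch by its size gives a per-sample average, which differs slightly from averaging per-batch means when the last batch is smaller.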