Model train, val, test workflow verification in PyTorch

I wasn't sure where the best place would be to get a code review on a seemingly working piece of PyTorch code. Could you please let me know if I am doing something wrong? I was able to fix my previous problem of test set accuracy being stuck at 0 or 1, and I now get an accuracy of around 70% on my test set. I would just like an expert opinion on whether the order of operations in the code below is OK.

    if train:
        for i_batch, sample_batched in enumerate(dataloader_train):
            feats, labels, _ = prepare_features(sample_batched['image'], sample_batched['label'],
                                                loss_type='CrossEntropyLoss', phase='train')
            output = model(feats)
            loss = criterion(output, labels)
            acc = (output.argmax(dim=1) == labels).float().mean()
            epoch_accuracy += acc / len(dataloader_train)
            epoch_loss += loss / len(dataloader_train)
    if not train:
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            epoch_val_preds = []
            epoch_val_labels = []
            total = 0.
            batch_idx = 0
            val_preds = []
            val_labels = []
            predictions = []
            actuals = []
            for i_batch, sample_batched in enumerate(dataloader_val):
                feats, labels, _ = prepare_features(sample_batched['image'], sample_batched['label']) 
                val_output = model(feats)
                val_loss = criterion(val_output, labels)
                acc = (val_output.argmax(dim=1) == labels).float().mean()
                epoch_val_accuracy += acc / len(dataloader_val)
                epoch_val_loss += val_loss / len(dataloader_val)

        if not test:
            if epoch_val_accuracy > best_val_acc:
                print('inside if - epoch is {}, val_acc is {}, and best_pred is {}'.format(epoch, epoch_val_accuracy, best_val_acc))
                best_val_acc = epoch_val_accuracy
                best_epoch = epoch
                best_preds = epoch_val_preds
                best_val_labels = epoch_val_labels
                print("Saving the best model...")
      , model_path + task_name + ".pth")

Please note that I use the same piece of code for the validation and test phases, but for testing I run a different bash script that sets the “test” flag and runs for only one epoch. If the flag is not set to “test”, I save the .pth file that corresponds to the best validation accuracy. My model is a vanilla vision transformer that returns a 2D tensor with 2 output values per sample instead of 1, i.e. of shape D x 2, where D is the batch size. I am using nn.CrossEntropyLoss for binary classification.
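To make the shapes concrete, here is a minimal sketch of how nn.CrossEntropyLoss is typically fed for a two-class problem. It assumes the model output has one row of two raw logits per sample, which is what `argmax(dim=1)` in the code above implies, and that the labels are integer class indices. The random `logits` tensor and the hard-coded `labels` are stand-ins, not the real model or data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, num_classes = 4, 2
# Stand-in for the model output: one row of raw logits per sample (no softmax).
logits = torch.randn(batch_size, num_classes)
# Class indices (0 or 1) with dtype long and shape (batch_size,).
labels = torch.tensor([0, 1, 1, 0])

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)          # scalar tensor, averaged over the batch

preds = logits.argmax(dim=1)              # predicted class per sample
acc = (preds == labels).float().mean()    # fraction of correct predictions
```

nn.CrossEntropyLoss applies log-softmax internally, so the model should return raw logits rather than probabilities.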

Most of my code above follows the snippet below, in code block #25.

Sorry, I am not able to fully fix the indentation on the PyTorch forum.

I think the code is correct (except for some indentation problems caused by the PyTorch forum :-) ).

However, all the conditions in the code can be avoided by creating separate functions for a training step and for validation/testing.

Like this:

def train_step(model, criterion, dataloader_train, ...):
    for i_batch, sample_batched in enumerate(dataloader_train):
        ...  # forward pass, loss, backward pass, optimizer step
    return loss, acc

def val_step(model, criterion, dataloader, ...):
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(dataloader):
            ...  # forward pass, loss, accuracy
    return loss, acc

# Training loop
for epoch in range(N_epochs):
    train_loss, train_acc = train_step(training data, ...)
    val_loss, val_acc = val_step(validation data, ...)
    if val_acc > best_val_acc:
        ...  # update best_val_acc and save the model

# Testing
test_loss, test_acc = val_step(test data, ...)
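
For reference, the outline above could be filled in roughly as follows. This is only a sketch under assumptions: the random tensors, the `nn.Linear` stand-in model, the SGD optimizer, and the in-memory `best_state` copy are all placeholders for the real dataset, the vision transformer, and the .pth file. The `model.train()` / `model.eval()` calls and the `zero_grad` / `backward` / `step` sequence are the usual order of operations, not something taken from the original post.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_step(model, criterion, optimizer, loader):
    """Run one training epoch; return (mean loss, accuracy)."""
    model.train()  # training mode: dropout and similar layers are active
    total_loss, correct, seen = 0.0, 0, 0
    for feats, labels in loader:
        optimizer.zero_grad()             # clear gradients from the previous step
        output = model(feats)             # forward pass, shape (batch, 2)
        loss = criterion(output, labels)
        loss.backward()                   # backward pass
        optimizer.step()                  # parameter update
        # .item() returns a plain float, so no autograd graph is kept alive
        total_loss += loss.item() * labels.size(0)
        correct += (output.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


def val_step(model, criterion, loader):
    """Evaluate without gradient tracking; shared by validation and test."""
    model.eval()  # evaluation mode: dropout is switched off
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for feats, labels in loader:
            output = model(feats)
            total_loss += criterion(output, labels).item() * labels.size(0)
            correct += (output.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


# Toy stand-ins (assumptions): random features, a linear "model", plain SGD.
torch.manual_seed(0)
X = torch.randn(96, 8)
y = (X[:, 0] > 0).long()                  # binary labels as class indices
train_loader = DataLoader(TensorDataset(X[:64], y[:64]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(X[64:80], y[64:80]), batch_size=16)
test_loader = DataLoader(TensorDataset(X[80:], y[80:]), batch_size=16)

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

best_val_acc, best_epoch, best_state = -1.0, -1, None
for epoch in range(5):
    train_loss, train_acc = train_step(model, criterion, optimizer, train_loader)
    val_loss, val_acc = val_step(model, criterion, val_loader)
    if val_acc > best_val_acc:
        best_val_acc, best_epoch = val_acc, epoch
        # Keep a copy of the best weights; torch.save(best_state, path) would persist them.
        best_state = copy.deepcopy(model.state_dict())

# Testing: restore the best weights, then reuse val_step on the test data.
model.load_state_dict(best_state)
test_loss, test_acc = val_step(model, criterion, test_loader)
```

Counting correct predictions and samples, instead of averaging per-batch accuracies, keeps the result exact even when the last batch is smaller than the others.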