Mismatch in data shapes

As I’m trying to run a training loop and training model on batches of data, I get this error message “ValueError: Expected input batch_size (16) to match target batch_size (32).” I

Below is my code :

#import tqdm for progress bar
from tqdm import tqdm
from tqdm.auto import tqdm

#set the seed and start the timer
torch.manual_seed(42)
train_time_start_on_cpu = timer()

#set the number of epochs
epochs = 10

#create a training and test loop
for epoch in tqdm(range(epochs)):
print(f"Epoch: {epoch}\n------")

#training
train_loss = 0
#add a loop to loop through the training batches
for batch, (X,y) in enumerate(train_dataloader):
    m_0.train()

    #forward pass
    y_pred = m_0(X)

    #calculate the loss per batch
    loss = loss_f(y_pred, y)
    train_loss += loss #to accumulate the train loss

    #optimize the zero grad
    optim.zero_grad()

    #loss backward // backpropagation step
    loss.backward()

    #optimizer step 
    optim.step()

    #print out what's happening 
    if batch %400 == 0:
        print(f" Looked at {batch * len(X)}/{len(train_dataloader.dataset)} samples")

#divide total train loss by lenght of train dataloader
train_loss /= len(train_dataloader)

#testing 
test_loss, test_acc = 0, 0
m_0.eval()
with torch.inference_mode():
    for X_test, y_test in test_dataloader:

        #forward pass
        test_pred = m_0(X_test)

        #calculate the loss accumulatively
        test_loss = loss_f(test_pred, y_pred)

        #calculate the accuracy
        test_acc = accuracy_fn(y_true=y_test, y_pred=test_pred.argmax(dim=1))

    #calculate the test loss average per batch
    test_loss /= len(test_dataloader)

    #calculate the test acc average per batch
    test_acc /= len(test_dataloader)

#print out what's happening 
print(f"Train loss : {train_loss:.4f} | Test loss : {test_loss:.4f} | Test acc : {test_acc:.4f}")

#calculate the training time
train_time_end_on_cpu = timer()
total_train_time_m_0 = print_train_time(start=train_time_start_on_cpu,
end=train_time_end_on_cpu,
device=str(next(m_0.parameters()).device)
)

Could someone please help me ?