Inference of a PyTorch model

Hi guys,
I tried to find a solution by myself, but I don't have enough experience. I have the following inference code for prediction:

test_loader = DataLoader(dataset=test_dataset, batch_size=1)

y_pred_list = []

with torch.inference_mode():
    model.eval()
    for X_batch in test_loader:
        X_batch = X_batch.to(device)
        y_test_pred = model(X_batch)                     # raw class scores
        _, y_pred_tags = torch.max(y_test_pred, dim=1)   # index of the highest score
        y_pred_list.append(y_pred_tags.cpu().numpy())

y_pred_list = [a.squeeze().tolist() for a in y_pred_list]

The output ‘y_test_pred’ is a tensor with 6 class scores, and torch.max takes the maximum of them. If I put batch_size=1, everything works great but very slowly, since the dataset is very large.
If I put batch_size=32, it works fast, but ‘y_test_pred’ comes with an additional dimension (32), and I can’t understand how to squeeze it, or maybe apply torch.max afterwards. Please help.

What error are you facing when you call torch.max on a batch of tensors?

Also, please post the exact shape of y_test_pred.

There is no error. If I put batch_size=1, the output shape is [1, 6] and ‘torch.max’ takes the max value out of the 6. But if I put batch_size=32, the output shape is [32, 6], and I don’t understand how to apply ‘torch.max’ to this shape.

See if this helps -

import torch

x = torch.randn(32, 6)                          # fake batch: 32 samples, 6 class scores each
max_values, max_indices = torch.max(x, dim=1)   # max over the class dimension, per sample
print(max_values.shape, max_indices.shape)

gives

torch.Size([32]) torch.Size([32])
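
To plug that back into your loop: something like this should work with batch_size=32 as well (just a sketch reusing the model, test_loader, and device from your first post, and assuming the model outputs shape [batch_size, 6] as you described):

import numpy as np
import torch

y_pred_list = []

with torch.inference_mode():
    model.eval()
    for X_batch in test_loader:
        X_batch = X_batch.to(device)
        y_test_pred = model(X_batch)                     # shape: [batch_size, 6]
        _, y_pred_tags = torch.max(y_test_pred, dim=1)   # shape: [batch_size]
        y_pred_list.append(y_pred_tags.cpu().numpy())

# flatten the per-batch arrays into one list of predicted class indices
y_pred_list = np.concatenate(y_pred_list).tolist()

If you only need the predicted indices, y_test_pred.argmax(dim=1) gives the same result without the max values.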

I don’t understand what happened. I redid everything step by step and it started to work even with batch size 32. Magic. Thank you very much. This is the second time you’ve helped me.