How to solve this torch.max() error?

I am using YOLOv5 to train a binary classifier. After training the model, I downloaded the weights to use them in my code:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import shutil
import copy
from PIL import Image
import glob
import cv2
from pygame import mixer

from models.experimental import attempt_load 
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
    strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized

# Load the model
filepath = 'weights/best.pt'
device = torch.device('cpu')
model = attempt_load(filepath, map_location=device)

# model = torch.load(filepath)

class_names = ['class1', 'class2']

mixer.init()
sound = mixer.Sound('files/alarm.wav')

def process_image(image):
    # Resize, center-crop, convert to tensor, and normalize with ImageNet statistics
    image_transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = image_transformation(image)
    return img

def classify_face(image):
    # Convert the BGR frame from OpenCV to an RGB PIL image and preprocess it
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    im = Image.fromarray(image)
    image = process_image(im)
    print('Image processed')
    image = image.unsqueeze(0).float()  # add a batch dimension
    
    model.eval()
    model.cpu()
    output = model(image)
    print("#################")
    print(output)
    print("#################")
    _, predicted = torch.max(output, 1)
    print(predicted.data[0], "predicted")
    
    classification1 = predicted.data[0]
    index = int(classification1)
    print(class_names[index])
    return class_names[index]

cap = cv2.VideoCapture(0)
font = cv2.FONT_HERSHEY_COMPLEX_SMALL
score = 0
thicc = 2

while True:
    ret, frame = cap.read()
    if not ret:  # stop if the webcam frame could not be read
        break
    height, width = frame.shape[:2]
    label = classify_face(frame)
    if label == "class1":
        print("Recognition")
    else:
        sound.play()
        print("Beep")
    cv2.putText(frame, str(label), (100, height-20), font, 1, (255,255,255), 1, cv2.LINE_AA)
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break 

cap.release()
cv2.destroyAllWindows()

After executing this code, I get the following error:

Fusing layers...
Image processed
#################
(tensor([[[6.69752e+00, 8.04415e+00, 3.28937e+01,  ..., 2.50250e-01, 8.44748e-01, 6.81155e-02],
         [1.33582e+01, 8.66212e+00, 3.57769e+01,  ..., 2.49377e-01, 9.11782e-01, 5.36799e-02],
         [2.17681e+01, 8.39124e+00, 3.61725e+01,  ..., 1.95423e-01, 9.29966e-01, 4.44486e-02],
         ...,
         [1.26728e+02, 2.13414e+02, 2.50512e+02,  ..., 1.54672e-03, 4.71090e-01, 6.19469e-01],
         [1.65840e+02, 2.11068e+02, 2.44700e+02,  ..., 3.60577e-03, 6.32448e-01, 4.47742e-01],
         [1.97628e+02, 1.99194e+02, 1.87245e+02,  ..., 2.99891e-02, 5.68435e-01, 4.69893e-01]]]), [tensor([[[[[ 7.01838e-01,  1.11338e+00,  2.27554e+00,  ..., -1.09728e+00,  1.69399e+00, -2.61600e+00],

         ... (feature-map values omitted) ...

           [-6.72526e-01, -5.64906e-01, -6.00370e-01,  ..., -3.47647e+00,  2.75467e-01, -1.20575e-01]]]]])])
#################
Traceback (most recent call last):
  File "WEBCAM_DETECT.py", line 94, in <module>
    label = classify_face(frame)
  File "WEBCAM_DETECT.py", line 78, in classify_face
    _, predicted = torch.max(output, 1)
TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of:
 * (Tensor input)
 * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)
 * (Tensor input, Tensor other, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)


Please let me know how I can solve this error.

Hi, as the error message suggests, the first argument to _, predicted = torch.max(output, 1) appears to be a tuple, but it should be a tensor. Printing the output also shows that it is a tuple: it starts with (tensor([[[6.6, note the leading (. So your model returns a tuple as its output. You can check this with type(output) and len(output). If the tuple has more than one element, you need to figure out which element your program needs; if it has only one element, just index it with output[0].
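Here is a minimal sketch of the fix, assuming the first tuple element is the tensor you want. In inference mode, a YOLOv5 detection model returns the concatenated detections tensor followed by the raw feature maps, which matches your print-out:

# inside classify_face(), reusing model and image from the code above
output = model(image)
print(type(output), len(output))  # confirm it is a tuple and how many elements it has

# unpack the tensor from the tuple before calling torch.max
pred = output[0] if isinstance(output, tuple) else output
_, predicted = torch.max(pred, 1)  # torch.max now receives a Tensor

Note that this only removes the TypeError: the first element of a YOLOv5 detection output is a tensor of raw box predictions, not class scores, so you will still have to decide how to map those values onto your two classes.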
