Based on the error message, it seems the input tensor uses the `uint8` dtype while the model uses `float32`, which is the dtype it expects. Note that `PILToTensor` keeps the dtype of the input image, which is most likely causing the issue. Use `ToTensor()` instead: it scales the input image to [0, 1] and returns it as `float32`, which should fix the error.