Expected scalar type Double but found Float

Hi,

Running inference with a state_dict saved from training a modified mobilenet_v2, I get a RuntimeError: expected scalar type Double but found Float, even though my input image is of type torch.double. What might be causing the type mismatch if not the input image itself? Below is my inference setup:

import numpy as np
import torch
from os import path
from torchvision.models import mobilenet_v2

model = mobilenet_v2()
num_classes = 16
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.load_state_dict(torch.load(path.join(...)))
...
image = torch.as_tensor(np.transpose(image, (2, 0, 1)), dtype=torch.double, device=None)
image = image[None, :]  # add batch dimension
pred = model(image)

...
     
     return F.conv2d(input, weight ...)
>> RuntimeError: expected scalar type Double but found Float
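
For reference, this is roughly how I would compare the dtype of the loaded weights against the dtype of the input right before the forward call (just a diagnostic sketch, not part of my original script):

# diagnostic only: compare parameter dtype vs. input dtype
print(next(model.parameters()).dtype)  # dtype of the weights after load_state_dict
print(image.dtype)                     # dtype of the input tensor
# conv2d raises the RuntimeError when these two differ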

Ok… when I do model = mobilenet_v2().float() and change the input image tensor to dtype=torch.float, the inference actually runs.

BUT, when both the model and image are set to double, it gives a RuntimeError: expected scalar type Float but found Double. After reading this thread, I still don’t see the pattern.

How does one interpret such behaviors?

Both dtypes work for me:

import torch
from torchvision import models

# float32
dtype = torch.float32
model = models.mobilenet_v2()
num_classes = 16
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.to(dtype)
image = torch.randn(1, 3, 224, 224, dtype=dtype)

pred = model(image) # works
print(pred.dtype)
# torch.float32

# float64
dtype = torch.float64
model = models.mobilenet_v2()
num_classes = 16
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.to(dtype)
image = torch.randn(1, 3, 224, 224, dtype=dtype)

pred = model(image) # works
print(pred.dtype)
# torch.float64
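
As far as I can tell, the error is only raised when the parameter dtype and the input dtype disagree, e.g. leaving the model in its default float32 while passing a float64 input (a minimal sketch of the failing case):

# dtype mismatch: float32 parameters, float64 input
model = models.mobilenet_v2()
num_classes = 16
model.classifier[1] = torch.nn.Linear(1280, num_classes)
image = torch.randn(1, 3, 224, 224, dtype=torch.float64)

pred = model(image)  # raises a RuntimeError about mismatched scalar types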

So could you post a code snippet reproducing the issue?