RuntimeError: expected scalar type Double but found Float, input and model data types match

I am getting this error even though I converted my model to float and the input data are already float. However, changing logits = self.model(x) to logits = self.model(x.float()) (in Train_utils.train_loop) fixes the issue.

I am interested in the root cause: converting the model's data type should already have fixed this, so why is the per-batch cast still needed?
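For context, the error can be reproduced in isolation by feeding a float64 (Double) tensor to a float32 model. A minimal sketch (the layer and shapes are arbitrary, chosen only to trigger the mismatch):

import torch
from torch.nn import Conv2d

conv = Conv2d(1, 8, 3)  # parameters are float32 by default
x64 = torch.randn(1, 1, 28, 28, dtype=torch.float64)  # Double input

# conv(x64)         # raises the RuntimeError from the title
conv(x64.float())   # works once the input is cast to float32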

Model class

from torch.nn import (Module, Sequential, Conv2d, MaxPool2d, ReLU,
                      BatchNorm2d, Flatten, Linear)

class Model(Module):
    def __init__(self):
        super().__init__()
        self.stack = Sequential(
            Conv2d(1, 8, 3, padding='same', padding_mode='reflect'),
            MaxPool2d(2),
            ReLU(),
            Conv2d(8, 8, 3, padding='same', padding_mode='reflect'),
            MaxPool2d(2),
            ReLU(),
            BatchNorm2d(8),
        )
        self.flatten = Flatten()
        self.linear = Linear(8 * 7 * 7, 10)  # 28x28 input -> 7x7 after two 2x2 pools

    def forward(self, x):
        x = self.stack(x)
        x = self.flatten(x)
        logits = self.linear(x)
        return logits
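A quick shape check (assuming 28x28 single-channel inputs, which the 8*7*7 flattened size implies):

import torch

m = Model()
x = torch.randn(4, 1, 28, 28)   # hypothetical float32 batch
print(m.stack(x).shape)         # torch.Size([4, 8, 7, 7])
print(m(x).shape)               # torch.Size([4, 10])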

Train-test loop without fix

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from tqdm import trange

class Train_utils:
    def __init__(self, model, train_dl, val_dl, optimizer, loss_fn=CrossEntropyLoss()):
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model
        self.loss_fn = loss_fn
        # optimizer is passed as a class and instantiated here
        self.optimizer = optimizer(self.model.parameters(), lr=0.001)

    def train_loop(self, num_epochs):
        train_loss_cache = []
        val_loss_cache = []

        for _ in trange(num_epochs):
            # switch back to training mode; eval() below would otherwise
            # persist into the next epoch
            self.model.train()
            epoch_loss_cache = []
            for x, y in self.train_dl:
                logits = self.model(x)
                loss = self.loss_fn(logits, y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss_cache.append(loss.detach().cpu())

            val_loss = []
            self.model.eval()
            with torch.no_grad():  # no gradients needed for validation
                for x, y in self.val_dl:
                    logits = self.model(x)
                    loss = self.loss_fn(logits, y)
                    val_loss.append(loss.detach().cpu())

            train_loss_cache.append(np.mean(epoch_loss_cache))
            val_loss_cache.append(np.mean(val_loss))
        return train_loss_cache, val_loss_cache
                
    def test(self, test_dl):
        counter = 0
        size = len(test_dl.dataset)

        self.model.eval()
        with torch.no_grad():  # inference only
            for x, y in test_dl:
                logits = self.model(x)
                # labels are one-hot float tensors, so compare argmax to argmax
                eq = logits.argmax(1) == y.argmax(1)
                counter += torch.sum(eq).item()

        accuracy = counter / size
        return accuracy
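For reference, the class above would be driven like this (a hedged sketch: Adam and the batch size are arbitrary choices, and val_data_torch / test_data_torch are hypothetical names for datasets from the setup not shown here):

from torch.optim import Adam
from torch.utils.data import DataLoader

train_dl = DataLoader(train_data_torch, batch_size=64, shuffle=True)
val_dl = DataLoader(val_data_torch, batch_size=64)
test_dl = DataLoader(test_data_torch, batch_size=64)

utils = Train_utils(Model(), train_dl, val_dl, optimizer=Adam)
train_losses, val_losses = utils.train_loop(num_epochs=5)
print(f'test accuracy : {utils.test(test_dl)}')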

Data type of inputs

x,y = train_data_torch[0]
print(f'x data type : {x.dtype}')
print(f'y data type : {y.dtype}')

x data type : torch.float32
y data type : torch.float32
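Note that this check indexes a single sample; it does not necessarily show what the DataLoader yields, since the default collate function stacks whatever the dataset returns at iteration time. Checking an actual batch is more conclusive (train_dl is the loader passed to Train_utils):

xb, yb = next(iter(train_dl))
print(f'batch x dtype : {xb.dtype}')
print(f'batch y dtype : {yb.dtype}')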

All the layers in the model have matching data types

state_dict = model_float.state_dict()
for key,value in state_dict.items():
    print(key,'---',value.dtype)

stack.0.weight --- torch.float32
stack.0.bias --- torch.float32
stack.3.weight --- torch.float32
stack.3.bias --- torch.float32
stack.6.weight --- torch.float32
stack.6.bias --- torch.float32
stack.6.running_mean --- torch.float32
stack.6.running_var --- torch.float32
stack.6.num_batches_tracked --- torch.int64
linear.weight --- torch.float32
linear.bias --- torch.float32

Check the dtype of x inside the DataLoader loop, immediately before

logits = self.model(x)

as it is most likely a float64 tensor at that point. The error message is written from the input's perspective: because the input is Double, the convolution expects Double weights, finds Float ones, and reports "expected scalar type Double but found Float" even though the model itself is float32. A common source of float64 batches is a dataset that returns numpy arrays (numpy defaults to float64); the DataLoader's default collate function preserves that dtype when it stacks samples into a batch.
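If the batch does turn out to be float64, there are two usual fixes: cast once when the dataset is built, or cast per batch (which is exactly what x.float() in train_loop does). A minimal sketch of the first option, with a hypothetical FloatDataset standing in for however train_data_torch is constructed:

import numpy as np
import torch
from torch.utils.data import Dataset

class FloatDataset(Dataset):
    # hypothetical stand-in for the dataset class behind train_data_torch
    def __init__(self, images, labels):
        # numpy arrays default to float64; cast once here instead of per batch
        self.images = torch.as_tensor(images, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# one-hot labels to match the y.argmax(1) comparison in test()
ds = FloatDataset(np.random.rand(100, 1, 28, 28),
                  np.eye(10)[np.random.randint(0, 10, 100)])
x0, y0 = ds[0]
print(x0.dtype, y0.dtype)  # torch.float32 torch.float32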