i am getting this error even after converting my model to float, while the input data are already float type. However, changing logits = model(x)
to logits=model(x.float())
(in Train_utils.train_loop) seems to fix the issue.
I am interested in the root cause, since changing the model data type should fix it.
Model class
class Model(Module):
def __init__(self):
super(Model,self).__init__()
self.stack = Sequential(Conv2d(1,8,3,padding='same', padding_mode='reflect'),
MaxPool2d(2),
ReLU()
Conv2d(8,8,3,padding='same', padding_mode='reflect'),
MaxPool2d(2),
ReLU(),
BatchNorm2d(8),
)
self.flatten = Flatten()
self.linear = Linear(8*7*7,10)
def forward(self,x):
x = self.stack(x)
x = self.flatten(x)
logits = self.linear(x)
return logits
Train-test loop without fix
class Train_utils:
def __init__(self, model, train_dl, val_dl, optimizer, loss_fn=CrossEntropyLoss()):
self.train_dl = train_dl
self.val_dl = val_dl
#self.test_dl = test_dl
#self.num_epochs = num_epochs
self.model = model
self.loss_fn = loss_fn
self.optimizer = optimizer(self.model.parameters(), lr=0.001)
def train_loop(self,num_epochs):
self.model.train()
train_loss_cache=[]
val_loss_cache=[]
for i in trange(num_epochs):
epoch_loss_cache=[]
for x,y in self.train_dl:
logits = self.model(x)
loss = self.loss_fn(logits,y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
epoch_loss_cache.append(loss.detach().cpu())
val_loss =[]
self.model.eval()
for x,y in val_dl:
logits = self.model(x)
loss = self.loss_fn(logits,y)
val_loss.append(loss.detach().cpu())
train_loss_cache.append(np.mean(epoch_loss_cache))
val_loss_cache.append(np.mean(val_loss))
return train_loss_cache, val_loss_cache
def test(self, test_dl):
counter=0
size = len(test_dl.dataset)
for x,y in test_dl:
logits = self.model(x)
eq = logits.argmax(1) == y.argmax(1)
counter+= torch.sum(eq)
accuracy = counter/size
return accuracy
Data type of inputs
x,y = train_data_torch[0]
print(f'x data type : {x.dtype}')
print(f'y data type : {y.dtype}')
x data type : torch.float32
y data type : torch.float32
All the layers in the model have matching data types
state_dict = model_float.state_dict()
for key,value in state_dict.items():
print(key,'---',value.dtype)
stack.0.weight --- torch.float32 stack.0.bias --- torch.float32 stack.3.weight --- torch.float32 stack.3.bias --- torch.float32 stack.6.weight --- torch.float32 stack.6.bias --- torch.float32 stack.6.running_mean --- torch.float32 stack.6.running_var --- torch.float32 stack.6.num_batches_tracked --- torch.int64 linear.weight --- torch.float32 linear.bias --- torch.float32