I am having trouble with tensor dtypes while training a model with a pretrained ResNet18 backbone for localization (classification and bounding box prediction on images). The backbone feeds two heads, each a single fully connected layer. The first head has 120 output features, one per class in the dataset; the second has 4 output features, one per bounding box coordinate. I use cross entropy loss on the output of the first head and mean squared error loss on the output of the second.
When a batch of images, labels, and bounding boxes is loaded from the dataloader, the labels and bounding boxes both have dtype int64. In the backpropagation step of the training loop, I get the following error: “RuntimeError: Found dtype Long but expected Float”. I assumed this was caused by the dtypes of the ground truth labels and bounding boxes, so I cast both tensors to float32 using .float().
However, if I do this cast before computing the cross entropy loss for the first head, I get the following error: “RuntimeError: expected scalar type Long but found Float”. This makes sense, since cross entropy is categorical and expects integer class indices rather than floats. But if I move the cast to after the losses have been calculated, I get the original error again: “RuntimeError: Found dtype Long but expected Float”. There must be something I am misunderstanding about backprop or the model, as I cannot see where a tensor with dtype Long is being used.
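For reference, here is a minimal standalone sketch of the dtype combination that the two loss functions appear to expect: integer class indices for cross entropy and a float target for mean squared error, with the cast applied before the loss is computed. The random tensors are stand-ins for the model outputs and the dataloader batch, and their shapes and value ranges are assumptions, not the actual data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for the model outputs (assumed: batch of 8, 120 classes, 4 box coordinates)
prediction_class = torch.randn(8, 120, requires_grad=True)
prediction_box = torch.randn(8, 4, requires_grad=True)

# Stand-ins for a dataloader batch in which both targets arrive as int64
class_batch = torch.randint(0, 120, (8,))   # int64 class indices
box_batch = torch.randint(0, 224, (8, 4))   # int64 pixel coordinates (assumed range)

# cross_entropy takes integer class indices, so class_batch is left as int64
loss_class = F.cross_entropy(prediction_class, class_batch)

# Only the box target is cast to float32, and the cast happens before the loss call
loss_box = F.mse_loss(prediction_box, box_batch.float())

total_loss = loss_class + loss_box
total_loss.backward()

print(prediction_class.grad.dtype, prediction_box.grad.dtype)
```

In this sketch the class target is never cast, so the two losses receive different target dtypes.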
I have attached the training loop and the model I am using below, together with the attributes needed to understand the training loop. If more code is needed for context I will gladly supply it, but I did not want to make the post too long.
Model used:
from einops import reduce
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights


class ResNet18Model(nn.Module):
    def __init__(self, pretrained=True, freeze_backbone=True):
        super().__init__()
        weights = None
        if pretrained:
            weights = ResNet18_Weights.IMAGENET1K_V1
        # Drop the final average pooling and fully connected layers
        self.backbone = nn.Sequential(*list(resnet18(weights=weights).children())[:-2])
        if freeze_backbone:
            for parameter in self.backbone.parameters():
                parameter.requires_grad = False
        self.classification_head = nn.Linear(512, 120)
        self.box_head = nn.Linear(512, 4)

    def forward(self, x):
        backbone_features = self.backbone(x)
        # Global average pooling over the spatial dimensions
        backbone_features = reduce(backbone_features, 'b c h w -> b c', reduction='mean')
        classification = self.classification_head(backbone_features)
        box = self.box_head(backbone_features)
        return classification, box
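As a side note on the pooling step in forward, the einops pattern 'b c h w -> b c' with reduction='mean' averages over the two spatial axes. The sketch below shows the same operation in plain PyTorch, assuming a 512-channel feature map of size 7x7 (what the truncated ResNet18 would produce for 224x224 inputs; the input size is an assumption).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the backbone output (assumed: batch of 2, 512 channels, 7x7 feature map)
features = torch.randn(2, 512, 7, 7)

# Mean over the spatial axes h and w, leaving one value per channel
pooled_mean = features.mean(dim=(2, 3))

# The same global average pooling expressed with adaptive_avg_pool2d
pooled_adaptive = F.adaptive_avg_pool2d(features, 1).flatten(1)

print(pooled_mean.shape)
print(torch.allclose(pooled_mean, pooled_adaptive, atol=1e-5))
```

Either form yields a [batch, 512] tensor, which matches the input size of the two linear heads.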
Context for training loop:
def base_lr_scheduler(t, T, lr):
    return lr  # constant schedule: the learning rate is returned unchanged

self.loss_classification = F.cross_entropy
self.loss_localization = F.mse_loss
self.optimizer = SGD(model.parameters(), lr=0.001)
self.learning_rate_scheduler = base_lr_scheduler
Training loop that produces error:
def train(self):
    for epoch in range(self.epochs):
        self.model.train()
        for x_batch, class_batch, box_batch in self.train_loader:
            box_batch = torch.squeeze(box_batch)

            # Update learning rate
            self.optimizer.param_groups[0]['lr'] = self.learning_rate_scheduler(
                self.current_batch_index, self.total_batches, lr=self.optimizer.param_groups[0]['lr'])

            # Forward pass
            prediction_class, prediction_box = self.model(x_batch)
            print('Prediction class shape:', prediction_class.shape)
            print('Prediction box shape:', prediction_box.shape)
            print('Prediction class type:', prediction_class.dtype)
            print('Prediction box type:', prediction_box.dtype)
            print('Batch class shape:', class_batch.shape)
            print('Batch box shape:', box_batch.shape)
            print('Batch class type:', class_batch.dtype)
            print('Batch box type:', box_batch.dtype)

            loss_class = self.loss_classification(prediction_class, class_batch)
            loss_box = self.loss_localization(prediction_box, box_batch)
            total_loss = loss_class + loss_box
            print('Loss class type', loss_class.dtype)
            print('Loss box type', loss_box.dtype)
            print('Total loss type', total_loss.dtype)

            class_batch = class_batch.float()
            box_batch = box_batch.float()
            print('Batch class type after .float():', class_batch.dtype)
            print('Batch box type after .float():', box_batch.dtype)

            # Backprop
            total_loss.backward()

            # Update model parameters
            self.optimizer.step()
            self.optimizer.zero_grad()
The code above returns the following when training the model:
Prediction class shape: torch.Size([64, 120])
Prediction box shape: torch.Size([64, 4])
Prediction class type: torch.float32
Prediction box type: torch.float32
Batch class shape: torch.Size([64])
Batch box shape: torch.Size([64, 4])
Batch class type: torch.int64
Batch box type: torch.int64
Loss class type torch.float32
Loss box type torch.float32
Total loss type torch.float32
Batch class type after .float(): torch.float32
Batch box type after .float(): torch.float32
Traceback (most recent call last):
  File "localization_test.py", line 33, in <module>
    dog_localizer.train()
  File "/home/nmunch/computer_vision_project/dog_localization_project/dog_localization_utilities/dog_localization_utilities/localization.py", line 113, in train
    total_loss.backward()
  File "/home/nmunch/environments/computer_vision/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/nmunch/environments/computer_vision/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Long but expected Float
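One way to narrow down which loss triggers the error in the backward pass is to run each loss and its backward call in isolation. The sketch below does this with random stand-in tensors; `backward_succeeds` is a hypothetical helper written for this illustration, the shapes and value ranges are assumptions, and the outcome for the int64 box target may depend on the PyTorch version, so no result is claimed for it here.

```python
import torch
import torch.nn.functional as F


def backward_succeeds(loss_fn, prediction, target):
    """Hypothetical helper: True if the loss and its backward pass raise no RuntimeError."""
    # Fresh leaf tensor so each check builds its own autograd graph
    prediction = prediction.detach().clone().requires_grad_(True)
    try:
        loss_fn(prediction, target).backward()
    except RuntimeError as error:
        name = getattr(loss_fn, "__name__", repr(loss_fn))
        print(f"{name} with {target.dtype} target: {error}")
        return False
    return True


torch.manual_seed(0)

# Stand-ins for the model outputs and an int64 dataloader batch (assumed shapes)
prediction_class = torch.randn(8, 120)
prediction_box = torch.randn(8, 4)
class_batch = torch.randint(0, 120, (8,))
box_batch = torch.randint(0, 224, (8, 4))

checks = {
    "cross_entropy, int64 target": backward_succeeds(F.cross_entropy, prediction_class, class_batch),
    "mse_loss, int64 target": backward_succeeds(F.mse_loss, prediction_box, box_batch),
    "mse_loss, float32 target": backward_succeeds(F.mse_loss, prediction_box, box_batch.float()),
}
for name, ok in checks.items():
    print(name, "->", "ok" if ok else "failed")
```

Each check builds its own graph from the target that was passed to the loss function, so the dtype that matters is the one the target has at the moment of the loss call.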
I feel like there is something obvious I am missing so I hope any of you can help me. Thanks in advance!