I am having trouble with tensor dtypes when training a model with a pretrained ResNet18 backbone to do localization (classification and bounding box prediction on images). The backbone feeds two heads, both simple fully connected layers. The first head has 120 output features, one for each class in the dataset; the second has 4 output features, one for each coordinate of the bounding box. I use cross entropy loss on the output of the first head and mean squared error loss on the output of the second.
When a batch of images, labels, and bounding boxes is loaded from the dataloader, the labels and bounding boxes both have dtype int64. In the backpropagation step of the training loop, I get the following error: “RuntimeError: Found dtype Long but expected Float”. I figured this was due to the dtypes of the ground truth labels and bounding boxes, so I cast the tensors to float32 using .float().
However, if I do this cast before computing the cross entropy loss for the first head, I get the following error: “RuntimeError: expected scalar type Long but found Float”. This makes sense, as cross entropy is categorical, so it expects an integer type rather than a float type. However, if I move the cast to after the losses have been calculated, I get the original error again: “RuntimeError: Found dtype Long but expected Float”. There must be something I am misunderstanding about backprop or the model, as I cannot see where a tensor with dtype Long is being used.
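To make the two dtype requirements concrete, here is a minimal sketch with random tensors standing in for my data. The values, and the 224 upper bound on the box coordinates, are made up; only the shapes match my setup. It assumes the cast is needed only on the box targets and has to happen before the loss is computed, which is my current reading of the two error messages rather than something I have confirmed in the full training loop.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for the two head outputs (64 samples, 120 classes, 4 box coordinates).
prediction_class = torch.randn(64, 120, requires_grad=True)
prediction_box = torch.randn(64, 4, requires_grad=True)

# Stand-ins for what the dataloader yields: both targets are int64.
class_batch = torch.randint(0, 120, (64,))
box_batch = torch.randint(0, 224, (64, 4))  # 224 is an assumed image size

# Cast only the box targets, and do it before the loss is computed.
# .float() returns a new tensor; it does not alter a loss that was already built.
box_batch = box_batch.float()

loss_class = F.cross_entropy(prediction_class, class_batch)  # class indices stay int64
loss_box = F.mse_loss(prediction_box, box_batch)             # regression target is float32
total_loss = loss_class + loss_box
total_loss.backward()

print(class_batch.dtype, box_batch.dtype, total_loss.dtype)
```

In this sketch the class labels are never cast, since cross entropy with class-index targets wants int64, while the MSE target is given the same float32 dtype as the prediction.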
I have attached the training loop and the model I am using below, together with the attributes needed to understand the training loop. If more code is needed to provide context I will gladly supply it, but figured I should not make the post too long.
    from einops import reduce
    import torch.nn as nn
    from torchvision.models import resnet18, ResNet18_Weights


    class ResNet18Model(nn.Module):
        def __init__(self, pretrained=True, freeze_backbone=True):
            super().__init__()

            weights = None
            if pretrained:
                weights = ResNet18_Weights.IMAGENET1K_V1

            # Drop the final average pooling and fully connected layers
            self.backbone = nn.Sequential(*list(resnet18(weights=weights).children())[:-2])

            if freeze_backbone:
                # parameters() yields the tensors themselves; named_parameters() yields (name, tensor) tuples
                for parameter in self.backbone.parameters():
                    parameter.requires_grad = False

            self.classification_head = nn.Linear(512, 120)
            self.box_head = nn.Linear(512, 4)

        def forward(self, x):
            backbone_features = self.backbone(x)
            # Global average pooling over the spatial dimensions
            backbone_features = reduce(backbone_features, 'b c h w -> b c', reduction='mean')
            classification = self.classification_head(backbone_features)
            box = self.box_head(backbone_features)
            return classification, box
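As a sanity check on the shapes, the pooling step and the two heads can be sketched with plain torch. The random tensor below is a made-up stand-in for the backbone output (I am assuming 512 channels on a 7x7 grid, which is what ResNet18 without its last two layers gives for 224x224 inputs), and the mean over the spatial dimensions should be equivalent to the einops reduce in the model.

```python
import torch
import torch.nn as nn

# Made-up stand-in for the backbone output: (batch, channels, height, width).
backbone_features = torch.randn(64, 512, 7, 7)

# Global average pooling over the spatial dimensions,
# equivalent to reduce(x, 'b c h w -> b c', reduction='mean').
pooled = backbone_features.mean(dim=(2, 3))

# Same head sizes as in the model: 120 classes and 4 box coordinates.
classification_head = nn.Linear(512, 120)
box_head = nn.Linear(512, 4)

classification = classification_head(pooled)
box = box_head(pooled)

print(classification.shape, box.shape)
print(classification.dtype, box.dtype)
```

Both outputs come out as float32, which matches the prediction dtypes printed in my training loop.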
Context for training loop:
    self.loss_classification = F.cross_entropy
    self.loss_localization = F.mse_loss
    self.optimizer = SGD(model.parameters(), lr=0.001)
    self.learning_rate_scheduler = base_lr_scheduler

    # Constant learning rate: returns lr unchanged
    def base_lr_scheduler(t, T, lr):
        return lr
Training loop that produces error:
    def train(self):
        for epoch in range(self.epochs):
            self.model.train()
            for x_batch, class_batch, box_batch in self.train_loader:
                box_batch = torch.squeeze(box_batch)

                # Update learning rate
                self.optimizer.param_groups[0]['lr'] = self.learning_rate_scheduler(
                    self.current_batch_index, self.total_batches,
                    lr=self.optimizer.param_groups[0]['lr'])

                # Forward pass
                prediction_class, prediction_box = self.model(x_batch)

                print('Prediction class shape:', prediction_class.shape)
                print('Prediction box shape:', prediction_box.shape)
                print('Prediction class type:', prediction_class.dtype)
                print('Prediction box type:', prediction_box.dtype)
                print('Batch class shape:', class_batch.shape)
                print('Batch box shape:', box_batch.shape)
                print('Batch class type:', class_batch.dtype)
                print('Batch box type:', box_batch.dtype)

                loss_class = self.loss_classification(prediction_class, class_batch)
                loss_box = self.loss_localization(prediction_box, box_batch)
                total_loss = loss_class + loss_box

                print('Loss class type', loss_class.dtype)
                print('Loss box type', loss_box.dtype)
                print('Total loss type', total_loss.dtype)

                class_batch = class_batch.float()
                box_batch = box_batch.float()

                print('Batch class type after .float():', class_batch.dtype)
                print('Batch box type after .float():', box_batch.dtype)

                # Backprop
                total_loss.backward()

                # Update model parameters
                self.optimizer.step()
                self.optimizer.zero_grad()
The code above returns the following when training the model:
    Prediction class shape: torch.Size([64, 120])
    Prediction box shape: torch.Size([64, 4])
    Prediction class type: torch.float32
    Prediction box type: torch.float32
    Batch class shape: torch.Size([64])
    Batch box shape: torch.Size([64, 4])
    Batch class type: torch.int64
    Batch box type: torch.int64
    Loss class type torch.float32
    Loss box type torch.float32
    Total loss type torch.float32
    Batch class type after .float(): torch.float32
    Batch box type after .float(): torch.float32
    Traceback (most recent call last):
      File "localization_test.py", line 33, in <module>
        dog_localizer.train()
      File "/home/nmunch/computer_vision_project/dog_localization_project/dog_localization_utilities/dog_localization_utilities/localization.py", line 113, in train
        total_loss.backward()
      File "/home/nmunch/environments/computer_vision/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
        torch.autograd.backward(
      File "/home/nmunch/environments/computer_vision/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
        Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    RuntimeError: Found dtype Long but expected Float
I feel like there is something obvious I am missing so I hope any of you can help me. Thanks in advance!