Need help debugging issue on fasterrcnn_mobilenet


For the last couple of weeks I have been struggling to debug a new problem when training my detection model. A couple of weeks ago training worked fine, but now I suddenly get an error I have never seen before.

I build my model and run training roughly like this, but still get the error.

net = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
# Replace the box predictor so the heads are re-initialized for my classes
# (assigning to out_features alone does not resize the layer weights)
in_features = net.roi_heads.box_predictor.cls_score.in_features
net.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, len(classes))

for epoch in range(epochs):
    train_one_epoch(net, optimizer, train_loader, device, epoch, print_freq=10)
    evaluate(net, test_loader, device=device)
print("Time for Total Training {:0.2f}".format(time.time() - start_time))

return net

I get an error that looks like this:

     26     for epoch in range(epochs):
---> 27         train_one_epoch(net, optimizer, train_loader, device, epoch, print_freq=10)
     28         evaluate(net, test_loader, device=device)

/content/ in train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq)
     44     if not math.isfinite(loss_value):
     45         print("Loss is {}, stopping training".format(loss_value))
---> 46         print(loss_dict_reduced)
     47         sys.exit(1)

/usr/local/lib/python3.7/dist-packages/torch/ in backward(self, gradient, retain_graph, create_graph, inputs)
    253             create_graph=create_graph,
    254             inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    257     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/ in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    147     Variable.execution_engine.run_backward(
    148         tensors, grad_tensors, retain_graph, create_graph, inputs,
--> 149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Found dtype Double but expected Float.

I also have my boxes and labels encoded like this:

boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)

and my images are float tensors.

How do I get rid of this runtime error?
All of my code (data class, imported libraries, and training loop) is over here

Thanks for the help,

Could you remove the torch.set_default_dtype line of code and explicitly cast the tensors to the desired dtype?
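For example, something along these lines (illustrative values, adapt the names to your code):

```python
import torch

# If torch.set_default_dtype(torch.float64) was called somewhere, every new
# floating-point tensor defaults to Double, which clashes with the model's
# Float (float32) weights during backward():
torch.set_default_dtype(torch.float64)
x = torch.tensor([1.0])
print(x.dtype)  # torch.float64

# Restore the float32 default and cast the tensors explicitly instead:
torch.set_default_dtype(torch.float32)
boxes = torch.as_tensor([[10.0, 20.0, 50.0, 60.0]], dtype=torch.float32)
labels = torch.as_tensor([1], dtype=torch.int64)
```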

Thanks for the help. Explicitly casting the tensors to the desired dtype fixed it.
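For anyone hitting the same error, the change boiled down to something like this (simplified sketch with a hypothetical helper name, not my exact code):

```python
import torch

def to_model_dtypes(image, boxes, labels):
    # Faster R-CNN expects float32 images and boxes and int64 labels;
    # anything arriving as float64 (e.g. from a numpy pipeline) triggers
    # "Found dtype Double but expected Float" in backward().
    image = image.to(torch.float32)
    target = {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }
    return image, target

# Example: a float64 image coming out of a numpy-based pipeline.
img, tgt = to_model_dtypes(
    torch.rand(3, 320, 320, dtype=torch.float64),
    [[10.0, 20.0, 50.0, 60.0]],
    [1],
)
print(img.dtype, tgt["boxes"].dtype, tgt["labels"].dtype)
```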