Backpropagation not working for YOLO V1 Model

Hello! New poster, so let me know if I need to change/add anything.

I am trying to create a simplified version of the yolov1 model for face detection.

I have created a custom Yolo module. The input is a n x 3 x 300 x 300 tensor of images and output is a n x 5 x 7 x 7 tensor of 49 regions and corresponding x, y, w, and h (only one predicted box per region for added simplicity).

class YOLO(nn.Module):
  def __init__(self):
    super(YOLO, self).__init__()
    self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(192, 256, kernel_size=3, padding=1)
    self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
    self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
    self.conv6 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv7 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv8 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv9 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv10 = nn.Conv2d(1024, 1024, kernel_size=3)
    
    self.relu_downsize = nn.Sequential(
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.LeakyReLU(.01, inplace=False)
    )
    
    self.linear_layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(7 * 7 * 1024,  4096),
        nn.LeakyReLU(.01, inplace=False),
        nn.Linear(4096, 7 * 7 * 5),
        nn.LeakyReLU(.01, inplace=False)
    )

    self.x_steps = torch.tile(torch.arange(start=0, end=7, step=1).view(1, 1, 7), (1, 7, 1))
    self.y_steps = torch.tile(torch.arange(start=0, end=7, step=1).view(1, 7, 1), (1, 1, 7))
    
  def forward(self, x):
    x = self.relu_downsize(self.conv1(x))
    x = self.relu_downsize(self.conv2(x))
    x = self.relu_downsize(self.conv3(x))
    x = F.leaky_relu(self.conv4(x), .01)
    x = self.relu_downsize(self.conv5(x))
    x = F.leaky_relu(self.conv6(x), .01)
    x = self.relu_downsize(self.conv7(x))
    x = F.leaky_relu(self.conv8(x))
    x = F.leaky_relu(self.conv9(x))
    x = F.leaky_relu(self.conv10(x))
    x = torch.sigmoid(self.linear_layers(x))
    x = x.view(-1, 5, 7, 7)
    x[:, :2, :, :] = x[:, :2, :, :] / 7
    x[:, 0, :, :] += torch.tile(self.x_steps.clone(), (x.shape[0], 1, 1))/7
    x[:, 1, :, :] += torch.tile(self.y_steps.clone(), (x.shape[0], 1, 1))/7

    return x

I have also created a YoloLoss module.

class YoloLoss(nn.Module):
  def __init__(self):
    super().__init__()

    self.mse_loss = nn.MSELoss()

    self.lambda_coord = .5
    self.lambda_noobj = .5

  def forward(self, yhat, y):
    """ localization + classification YOLO loss
    yhat is a (n x 5 x 7 x 7) tensor
    y is a list of size n, each element being a (# faces x 4) tensor"""

    num_images = len(y)

    err_loc = 0
    err_size = 0
    err_inclass = 0
    err_outclass = 0

    for img_num, act_faces in enumerate(y):
      responsible_regions = []

      for face_num, act_face in enumerate(act_faces):
        face_region = torch.floor(act_face[:2] * 7).type(torch.LongTensor).detach()

        pred_loc = yhat[img_num, :2, face_region[0], face_region[1]]
        pred_size = yhat[img_num, 2:4, face_region[0], face_region[1]]
        pred_prob = yhat[img_num, 4, face_region[0], face_region[1]]
        responsible_regions.append(face_region)

        err_loc += self.mse_loss(pred_loc, act_face[:2])
        err_size += self.mse_loss(torch.sqrt(pred_size), torch.sqrt(act_face[2:4]))
        err_inclass += self.mse_loss(pred_prob, torch.ones_like(pred_prob))
        err_outclass -= self.mse_loss(pred_prob, torch.zeros_like(pred_prob))

      err_outclass += self.mse_loss(yhat[img_num, 4, :, :], torch.zeros_like(yhat[img_num, 4, :, :]))

    return err_loc + err_size + err_inclass + err_outclass

When I run an example through the model, I get a good output tensor. When I compare that output to the actual with YoloLoss, I get a value that looks right. When I do loss.backward(), I then get this error:

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py:175: UserWarning: Error detected in MseLossBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
    handler_func(fileobj, events)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-41-a13bfa0cb0de>", line 44, in <module>
    loss = loss_func(preds, boxes_all)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "<ipython-input-41-a13bfa0cb0de>", line 34, in forward
    err_size += self.mse_loss(torch.sqrt(pred_size), torch.sqrt(act_face[2:4]))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 529, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 3262, in mse_loss
    return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-41-a13bfa0cb0de> in <module>()
     43 loss_func = YoloLoss()
     44 loss = loss_func(preds, boxes_all)
---> 45 loss.backward()

1 frames
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: Found dtype Double but expected Float

I am not sure how to investigate the problematic tensor.

I think it’s due to the data type of your labels which might be tensor double/long while having your model declared with float tensor.
Could you replace it with float tensors and check if it solves the issue?

Unfortunately face_region has to be a long tensor because it is used as indices for which prediction region to attribute the predicted face to. See how pred_loc, pred_size, and pred_prob are calculated. Those values depend on the attributable region.

If I make it a float tensor, the loss function can’t run at all.

Right I didn’t see that.
How is act-face defined? I assume that’s your label having declared in double?

In this case, act_faces is a num_faces x 4 tensor (one of these tensors for each image).
act_face is just a vector tensor of length 4 (x, y, w, and h for the actual face bounding box of the image)

If I use the following print statements:

print(type(y))

print(type(y[0]))
print(y[0].shape)
print(y[0].dtype)

print(type(y[0][0]))
print(y[0][0].shape)
print(y[0][0].dtype)

I get the following output:

<class 'list'>

<class 'torch.Tensor'>
torch.Size([2, 4])
torch.float64

<class 'torch.Tensor'>
torch.Size([4])
torch.float64

This means there are 2 actual faces in the image

This is the image and the actual boxes:
image

Hmm, that looks good to me.
I would want to check if the input/output/model tensor has also the same type but I don’t see any reason why that shouldn’t be been the case.

yolo_model = YOLO()
print(type(imgs))
print(imgs.dtype)
print(imgs.shape)

preds = yolo_model(imgs)
print(type(preds))
print(preds.dtype)
print(preds.shape)

Gives the following:

<class 'torch.Tensor'>
torch.float32
torch.Size([4, 3, 300, 300])

<class 'torch.Tensor'>
torch.float32
torch.Size([4, 5, 7, 7])

Ok so that looks good too, the weights in Pytorch are initialised as float32, thus the output as float32 as well.
That might have been the case, because your bboxes are defined as float64 instead, so there is a miss match. I’d try to transform the targets to float32 and see what happens.

What’s weird is though that the error isn’t thrown during the forward pass.

That doesn’t seem to be helping either :confused:

I found the error.

When I define the YOLO model, it 1) self assignment (i.e. x = x + y) and 2) tensors with no gradient attached.

self.x_steps and self.y_steps need to be autograd tensors (torch.autograd.Variable(self.x_steps), for example). Also, I don’t know why this is needed, but the last three steps of forward were changed to a new variable on each line and that fixed it.

See below for the full new YOLO model.

class YOLO(nn.Module):
  def __init__(self):
    super(YOLO, self).__init__()
    self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(192, 256, kernel_size=3, padding=1)
    self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
    self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
    self.conv6 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv7 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv8 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv9 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.conv10 = nn.Conv2d(1024, 1024, kernel_size=3)
    
    self.relu_downsize = nn.Sequential(
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.LeakyReLU(.01, inplace=False)
    )
    
    self.linear_layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(7 * 7 * 1024,  4096),
        nn.LeakyReLU(.01, inplace=False),
        nn.Linear(4096, 7 * 7 * 5),
        nn.LeakyReLU(.01, inplace=False)
    )


    self.scale_factor = torch.ones(1, 5, 7, 7)
    self.scale_factor[:, :2, :, :] = self.scale_factor[:, :2, :, :] / 7
    x_steps = torch.tile(torch.arange(start=0, end=7, step=1).view(1, 1, 7), (1, 7, 1)) / 7
    y_steps = torch.tile(torch.arange(start=0, end=7, step=1).view(1, 7, 1), (1, 1, 7)) / 7
    no_steps = torch.zeros(1, 7, 7)
    self.steps = torch.stack([x_steps, y_steps, no_steps, no_steps, no_steps], axis=1)

  def forward(self, x):
    """ There are 49 "pixels", or output regions where faces will be predicted
    Each output region will have the following dimensions as channels: 
      - Center coordinate x (as a % of region)
      - Center coordinate y (as a % of region)
      - Width (as a % of image)
      - Height (as a & of region)
      - Pr(center of face is contained in this region)
    Output should then be n x 5 x 7 x 7
    Notice that all of the above are limited to [0, 1], so sigmoid is applied
    """

    N = x.shape[0]
    scale_factor = torch.autograd.Variable(torch.tile(self.scale_factor.clone(), (N, 1, 1, 1))).to(x.device)
    steps = torch.autograd.Variable(torch.tile(self.steps.clone(), (N, 1, 1, 1))).to(x.device)
    
    x = self.relu_downsize(self.conv1(x))
    x = self.relu_downsize(self.conv2(x))
    x = self.relu_downsize(self.conv3(x))
    x = F.leaky_relu(self.conv4(x), .01)
    x = self.relu_downsize(self.conv5(x))
    x = F.leaky_relu(self.conv6(x), .01)
    x = self.relu_downsize(self.conv7(x))
    x = F.leaky_relu(self.conv8(x))
    x = F.leaky_relu(self.conv9(x))
    x = F.leaky_relu(self.conv10(x))
    x = torch.sigmoid(self.linear_layers(x))
    x = x.view(-1, 5, 7, 7)
    x_scaled = x * scale_factor
    x_out = x + steps

    return x_out

I also changed the structure of the true bounding boxes to be n x 5 x 7 x 7 tensors instead of lists so that I could rewrite the following loss function:

class YoloLoss(nn.Module):
  def __init__(self):
    super().__init__()

    self.mse_loss = nn.MSELoss()

    self.lambda_coord = torch.autograd.Variable(torch.tensor(.5))
    self.lambda_noobj = torch.autograd.Variable(torch.tensor(.5))

  def forward(self, yhat, y):
    """ localization + classification YOLO loss
    yhat is a (n x 5 x 7 x 7) tensor
    y is a (n x 5 x 7 x 7) tensor"""

    region_yhat = yhat * torch.tile(y[:, 4, :, :].unsqueeze(1), (1, 5, 1, 1))
    nonregion_yhat = (1 - yhat) * torch.tile((1 - y[:, 4, :, :].unsqueeze(1)), (1, 5, 1, 1))

    err_loc = self.mse_loss(region_yhat[:, :2, :, :], y[:, :2, :, :])
    err_size = self.mse_loss(torch.sqrt(region_yhat[:, 2:4, :, :]), torch.sqrt(y[:, 2:4, :, :]))
    err_inclass = self.mse_loss(region_yhat[:, 4, :, :], y[:, 4, :, :])
    err_outclass = self.mse_loss(nonregion_yhat[:, 4, :, :], 1 - y[:, 4, :, :])

    return err_loc + err_size + err_inclass + err_outclass