YoloV8: gradients of prediction scores w.r.t. input images are NaN

Hello all,

I am trying to generate attribution maps for YoloV8. I was able to do this with YoloV7 relatively easily.

When I try the same on YoloV8, I always get a tensor full of NaNs. This is odd, because the general process I am following is similar to how I did it in YoloV7. The architecture has not changed much between them, but the code has, since YoloV8 comes from Ultralytics while YoloV7 was developed by a different company.

This is what I’ve tried:

  • Used detect_anomaly to find the root cause of the NaNs during the backprop. Torch reports: 'ConvolutionBackward0' returned nan values in its 0th output. This is the initial layer. I also checked the inputs for NaNs, and they are clean.
  • Double-checked the gradient computation to make sure it is correct per my YoloV7 implementation and common vanilla gradient attribution methods.
  • Debugged the YoloV8 forward pass for any unusual manipulation of the input data. I could not find any.
  • Simplified the code down to a minimal example to verify and reproduce the issue.
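As a minimal sketch of the detect_anomaly step (a toy example, not the YoloV8 model), the snippet below runs a backward pass under torch.autograd.detect_anomaly() and returns the error that names the first backward node producing NaN. The helpers nan_grad_source and masked_log are hypothetical; masked_log is just a convenient way to get a finite forward pass combined with a NaN gradient.

```python
import torch


def nan_grad_source(fn, x):
    """Run fn(x).sum().backward() under anomaly detection.

    Returns the RuntimeError message naming the first backward node that
    produced NaN, or None if the backward pass stayed finite.
    """
    try:
        with torch.autograd.detect_anomaly():
            fn(x).sum().backward()
    except RuntimeError as err:
        return str(err)
    return None


def masked_log(t):
    # Hypothetical toy op: the forward output is finite, but at t == 0 the
    # unused log branch receives a 0 gradient and LogBackward computes 0 / 0 = NaN.
    return torch.where(t > 0, torch.log(t), torch.zeros_like(t))


x = torch.tensor([0.0, 1.0], requires_grad=True)
assert not torch.isnan(x).any()  # the input itself is clean
print(nan_grad_source(masked_log, x))
```

In anomaly mode PyTorch raises a RuntimeError of the form "Function 'LogBackward0' returned nan values in its 0th output.", the same form as the ConvolutionBackward0 message reported above.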

My Code:

```python
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import RANK
import torch

class CustomModel(DetectionModel):
    def loss(self, batch, preds=None):
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()

        # Track gradients w.r.t. the (already preprocessed) input images
        imgs = batch['img'].requires_grad_(True)
        preds = self.forward(imgs) if preds is None else preds
        loss = self.criterion(preds, batch)
        pred_scores = self.get_pred_scores(preds)
        # Using autograd.grad so I can change the target, e.g. to test w.r.t. loss, obj_loss, etc.
        gradients = torch.autograd.grad(pred_scores,
                                        imgs,
                                        grad_outputs=torch.ones_like(pred_scores),
                                        retain_graph=True)[0]
        # The rest of the attribution method would go here (assume for now it's just vanilla gradient attribution)

        return loss

    def get_pred_scores(self, preds):
        # Taken from v8DetectionLoss: split the concatenated head outputs into
        # box-distribution (reg_max * 4) and class (nc) channels
        feats = preds[1] if isinstance(preds, tuple) else preds
        _, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.model[-1].no, -1) for xi in feats], 2).split(
            (self.model[-1].reg_max * 4, self.model[-1].nc), 1)
        # (batch, nc, anchors) -> (batch, anchors, nc)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        return pred_scores

class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a YOLO detection model."""
        model = CustomModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model
```
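To check the attribution step in isolation, here is a hedged, self-contained sketch of vanilla gradient attribution on a toy model. vanilla_gradient and score_fn are hypothetical names (score_fn plays the role of get_pred_scores), and disabling autocast inside the helper is an assumption about one way to keep this single computation in float32, not something tested in this thread.

```python
import torch
import torch.nn as nn


def vanilla_gradient(model, imgs, score_fn):
    """Vanilla gradient attribution: d(sum of scores) / d(imgs).

    Autocast is disabled and the input is cast to float32, so the forward pass
    (and therefore its backward) runs in full precision even if the caller is
    inside an AMP region. Assumes the model's parameters are float32.
    """
    imgs = imgs.detach().float().requires_grad_(True)
    with torch.autocast(device_type=imgs.device.type, enabled=False):
        scores = score_fn(model(imgs))
    # grad_outputs of ones is equivalent to differentiating scores.sum()
    (grad,) = torch.autograd.grad(scores, imgs, grad_outputs=torch.ones_like(scores))
    return grad


# Toy usage: a bias-free linear layer with known weights
lin = nn.Linear(3, 2, bias=False)
with torch.no_grad():
    lin.weight.copy_(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
g = vanilla_gradient(lin, torch.ones(1, 3), score_fn=lambda out: out)
print(g)  # tensor([[5., 7., 9.]])
```

For a bias-free linear layer, the gradient of the summed outputs is just the column sums of the weight matrix, which makes the result easy to verify by hand.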

Assume this CustomTrainer is passed to the model.train method and run.
Luckily, this takes far less code than in YoloV7, as Ultralytics does a good job with abstraction!

UPDATE:

After more investigation, the gradients seem to turn to NaN after layer 16. The backprop is able to get through layers 20 to 16, and then all gradients after that (in the layers closer to the input) become NaN. This would explain why the gradient w.r.t. the input is all NaN: once a NaN enters the backward pass, it propagates to every earlier layer. The real question is why. Is the gradient exploding? Is the model frozen past this layer? I will keep investigating, but I still urge anyone who knows what is happening to reach out!
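One way to find the layer where the gradient first turns NaN is to attach a tensor hook to every layer's output and record whether the incoming gradient contains NaN. The sketch below does this for a toy nn.Sequential. attach_nan_probes and MaskedLog are hypothetical, and applying it to the Ultralytics model assumes its indexed layers live in model.model, which is worth verifying in your version.

```python
import torch
import torch.nn as nn


def attach_nan_probes(layers):
    """For each indexed layer, record whether the gradient w.r.t. its output contains NaN.

    Returns (report, handles); call h.remove() on each handle when done.
    """
    report, handles = {}, []
    for idx, layer in enumerate(layers):
        def forward_hook(module, inputs, output, idx=idx):
            # Only plain tensors that are part of the graph can carry a gradient hook
            if isinstance(output, torch.Tensor) and output.requires_grad:
                def grad_hook(grad, idx=idx):
                    report[idx] = bool(torch.isnan(grad).any())
                output.register_hook(grad_hook)
        handles.append(layer.register_forward_hook(forward_hook))
    return report, handles


class MaskedLog(nn.Module):
    """Hypothetical layer: finite forward, NaN gradient wherever its input is 0."""
    def forward(self, x):
        return torch.where(x > 0, torch.log(x), torch.zeros_like(x))


torch.manual_seed(0)
model = nn.Sequential(nn.ReLU(), MaskedLog(), nn.Linear(4, 1))
report, handles = attach_nan_probes(model)
x = torch.tensor([[-1.0, 2.0, 3.0, 4.0]], requires_grad=True)
model(x).sum().backward()
print(dict(sorted(report.items())))  # {0: True, 1: False, 2: False}
for h in handles:
    h.remove()
```

Reading the report: the gradient w.r.t. layer 0's output contains NaN while the gradient w.r.t. layer 1's output does not, so the NaN is produced inside layer 1's backward.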

After a few more days of experimentation, I ran the model on a CPU device and the gradients were populated successfully! Upon further investigation, it seems that AMP (Automatic Mixed Precision) is only supported on specific GPUs. Even though YOLOv8 checks for compatibility, the check doesn't seem perfect. After disabling AMP in the training args, I got populated gradients and was able to use a GPU device instead!
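One plausible mechanism, which this thread does not confirm, is float16 overflow. On CUDA, AMP runs many operations in half precision, whose largest finite value is 65504. A large intermediate value then becomes inf, and inf turns into NaN through ordinary arithmetic. This toy sketch (overflows_in_fp16 is a hypothetical helper) shows both steps on the CPU:

```python
import torch

FP16_MAX = torch.finfo(torch.float16).max  # largest finite float16 value


def overflows_in_fp16(values):
    """Return, per value, whether casting it from float32 to float16 gives +/-inf."""
    t = torch.tensor(values, dtype=torch.float32)
    # Cast back to float32 before isinf so the check itself runs in full precision
    return torch.isinf(t.to(torch.float16).float()).tolist()


print(FP16_MAX)                                        # 65504.0
print(overflows_in_fp16([1000.0, 70000.0, -70000.0]))  # [False, True, True]

# Once an inf appears, inf - inf and 0 * inf both give NaN, and during backprop
# that NaN then propagates through every earlier layer.
g = torch.tensor([70000.0]).to(torch.float16).float()  # inf after the fp16 round-trip
print(g - g, 0 * g)                                    # tensor([nan]) tensor([nan])
```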

To fix this, disable AMP, either by passing amp=False to model.train() or by passing it to the trainer manually:

```python
results = model.train(data='coco128.yaml', epochs=100, imgsz=640, amp=False)
```

First of all, thanks for sharing your experience and knowledge.
I am facing a problem as well. I have disabled AMP by passing amp=False,
but I am using a GPU (Tesla T4) and getting this error:

```
RuntimeError                              Traceback (most recent call last)
in <cell line: 3>()
      1 # Backward pass to get the gradient
      2 model.zero_grad()
----> 3 loss.backward()
      4 data_grad = image.grad.data

1 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

Is this problem related to my GPU? What type of GPU can I use?

No, the GPU you use is irrelevant, as the error is raised by Autograd. You might have disabled gradient calculation globally, detached a forward activation, or used non-differentiable operations.

What does cls_id.grad_fn return? If it’s None, the tensor is not attached to any computation graph.
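The three causes listed can be reproduced on a toy tensor. In each case grad_fn is None and backward() raises a RuntimeError like the one in the traceback above. backward_error is a hypothetical helper used only for this illustration:

```python
import torch


def backward_error(t):
    """Return the RuntimeError message from t.backward(), or None if it succeeds."""
    try:
        t.backward()
    except RuntimeError as err:
        return str(err)
    return None


x = torch.randn(4, 3, requires_grad=True)

with torch.no_grad():          # 1) gradient calculation disabled
    a = (x * 2).sum()

b = (x.detach() * 2).sum()     # 2) forward activation detached

c = (x * 2).argmax()           # 3) non-differentiable op (integer output)

for name, t in [("no_grad", a), ("detach", b), ("argmax", c)]:
    # grad_fn is None in all three cases: there is no graph to backprop through
    print(name, t.grad_fn, backward_error(t))
```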

Thanks @ptrblck

Yes, cls_id.grad_fn returns None.
What should I do now? How can I calculate the gradient? Please give me a suggestion.

I’m not familiar with your model, but you should check my previous post, in particular the model.predict call, to narrow down where the computation graph is detached or whether gradient calculation is disabled in predict.
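A small, hypothetical helper for this kind of narrowing down: pass your intermediate tensors in forward order and it returns the first one that autograd no longer tracks. It also prints whether grad mode or inference mode is active, since running the model under torch.no_grad() or torch.inference_mode() detaches everything. The tensors below are toy stand-ins, and the inference_mode block only illustrates what such a wrapper would do; whether your predict call uses one is something to check in your own setup.

```python
import torch


def first_graph_break(named_tensors):
    """Return the name of the first tensor (in forward order) that autograd no
    longer tracks, or None if every tensor still requires grad.

    requires_grad is checked instead of grad_fn because a leaf input created
    with requires_grad=True legitimately has grad_fn None.
    """
    for name, t in named_tensors:
        if not t.requires_grad:
            return name
    return None


print("grad mode enabled:", torch.is_grad_enabled())
print("inference mode enabled:", torch.is_inference_mode_enabled())

image = torch.rand(1, 3, 8, 8, requires_grad=True)
feats = image * 2.0                     # still tracked
scores = feats.flatten(1).softmax(-1)   # still tracked
cls_id = scores.argmax(-1)              # integer indices: the graph ends here
print(first_graph_break([("image", image), ("feats", feats),
                         ("scores", scores), ("cls_id", cls_id)]))  # prints: cls_id

with torch.inference_mode():            # what an inference-mode wrapper would do
    pred = image * 2.0
print(first_graph_break([("image", image), ("pred", pred)]))       # prints: pred
```

Because argmax returns integer indices, it never carries a gradient. If the goal is the gradient of a class score, a common approach is to differentiate the score value itself (for example scores.max()) rather than cls_id.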