I’ve been told to use the `forward` function from torchvision's `generalized_rcnn` module,
iterate through my dataloader, and switch everything to `.eval()` mode in order to
compute the losses. Would the following be appropriate?
from typing import Tuple, List, Dict, Optional
import torch
from torch import Tensor
from collections import OrderedDict
from torchvision.models.detection.roi_heads import fastrcnn_loss
from torchvision.models.detection.rpn import concat_box_prediction_layers
def eval_forward(model, images, targets):
    # type: (torch.nn.Module, List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """Run a torchvision Faster R-CNN in eval mode while also computing its losses.

    This mirrors ``GeneralizedRCNN.forward`` but calls the RPN and box ROI-head
    components directly. One pass therefore yields both the loss dict (normally
    only returned in train mode) and the detections (normally only returned in
    eval mode).

    * Losses are computed from the train-mode proposal path: train-mode RPN
      top-N limits, then ROI sampling and matching against the ground truth.
      This keeps them comparable to the losses reported during training.
    * Detections are computed from the eval-mode proposal path. They are NOT
      taken from the sampled training proposals, which contain the ground-truth
      boxes themselves.

    Gradients are not disabled here. Wrap the call in ``torch.no_grad()`` when
    the losses are only being monitored (e.g. in a validation loop).

    Args:
        model: torchvision Faster R-CNN style model exposing ``transform``,
            ``backbone``, ``rpn`` and ``roi_heads``.
        images (list[Tensor]): images to be processed.
        targets (list[Dict[str, Tensor]] or None): per-image ground truth with
            ``boxes`` and ``labels``. When None, no losses are computed and the
            returned loss dict is empty.

    Returns:
        (losses, detections):
            losses: ``loss_classifier``, ``loss_box_reg``, ``loss_objectness``
                and ``loss_rpn_box_reg`` (empty when ``targets`` is None).
            detections: one dict per image with ``boxes``, ``labels`` and
                ``scores``, mapped back to the original image sizes.

    Raises:
        ValueError: if a target box has non-positive width or height.

    Side effects:
        Leaves ``model`` in eval mode; call ``model.train()`` before resuming
        training.
    """
    model.eval()

    original_image_sizes: List[Tuple[int, int]] = []
    for img in images:
        height, width = img.shape[-2:]
        original_image_sizes.append((height, width))

    images, targets = model.transform(images, targets)
    if targets is not None:
        _check_degenerate_boxes(targets)

    features = model.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])

    try:
        train_proposals, eval_proposals, proposal_losses = _rpn_forward(
            model.rpn, images, features, targets
        )
        detector_losses, detections = _roi_heads_forward(
            model.roi_heads,
            features,
            train_proposals,
            eval_proposals,
            images.image_sizes,
            targets,
        )
    finally:
        # model.eval() put both sub-modules in eval mode. _rpn_forward flips
        # rpn.training temporarily; never let that leak out, even on error.
        model.rpn.training = False
        model.roi_heads.training = False

    detections = model.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

    losses: Dict[str, Tensor] = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)
    return losses, detections


def _check_degenerate_boxes(targets):
    # type: (List[Dict[str, Tensor]]) -> None
    """Raise ValueError if any target box has non-positive width or height."""
    for target_idx, target in enumerate(targets):
        boxes = target["boxes"]
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
        if degenerate_boxes.any():
            # report the first degenerate box
            bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
            degen_bb: List[float] = boxes[bb_idx].tolist()
            raise ValueError(
                "All bounding boxes should have positive height and width."
                f" Found invalid box {degen_bb} for target at index {target_idx}."
            )


def _rpn_forward(rpn, images, features, targets):
    """Run the RPN once; return (train_proposals, eval_proposals, rpn_losses).

    The RPN head and anchor generator run a single time. The decoded proposals
    are then filtered twice, because ``rpn.filter_proposals`` takes its
    pre/post-NMS top-N limits from ``rpn.training``:

    * ``training=False`` gives the proposals an eval-mode forward would use
      (these feed the detections);
    * ``training=True`` gives the proposals a train-mode forward would use
      (these feed the ROI-head losses).

    When ``targets`` is None, only the eval proposals are produced,
    ``train_proposals`` is None and the loss dict is empty. The caller is
    responsible for resetting ``rpn.training`` afterwards.
    """
    features_rpn = list(features.values())
    objectness, pred_bbox_deltas = rpn.head(features_rpn)
    anchors = rpn.anchor_generator(images, features_rpn)

    num_images = len(anchors)
    num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
    num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
    objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
    # Apply pred_bbox_deltas to the anchors to obtain the decoded proposals.
    # The deltas are detached because Faster R-CNN does not backprop through
    # the proposals.
    proposals = rpn.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    proposals = proposals.view(num_images, -1, 4)

    rpn.training = False
    eval_proposals, _ = rpn.filter_proposals(
        proposals, objectness, images.image_sizes, num_anchors_per_level
    )
    if targets is None:
        return None, eval_proposals, {}

    rpn.training = True
    train_proposals, _ = rpn.filter_proposals(
        proposals, objectness, images.image_sizes, num_anchors_per_level
    )

    labels, matched_gt_boxes = rpn.assign_targets_to_anchors(anchors, targets)
    regression_targets = rpn.box_coder.encode(matched_gt_boxes, anchors)
    loss_objectness, loss_rpn_box_reg = rpn.compute_loss(
        objectness, pred_bbox_deltas, labels, regression_targets
    )
    rpn_losses = {
        "loss_objectness": loss_objectness,
        "loss_rpn_box_reg": loss_rpn_box_reg,
    }
    return train_proposals, eval_proposals, rpn_losses


def _box_predictions(roi_heads, features, proposals, image_shapes):
    """Pool features for ``proposals``; return (class_logits, box_regression)."""
    box_features = roi_heads.box_roi_pool(features, proposals, image_shapes)
    box_features = roi_heads.box_head(box_features)
    return roi_heads.box_predictor(box_features)


def _roi_heads_forward(roi_heads, features, train_proposals, eval_proposals, image_shapes, targets):
    """Return (detector_losses, detections) from the box ROI heads.

    Losses: ``select_training_samples`` matches ``train_proposals`` to the
    ground truth and subsamples them (it also adds the GT boxes as proposals),
    then ``fastrcnn_loss`` is applied. Skipped when ``targets`` is None.

    Detections: ``eval_proposals`` are used unchanged, so the predictions are
    not biased by the ground-truth boxes injected during sampling.
    """
    detector_losses: Dict[str, Tensor] = {}
    if targets is not None:
        sampled_proposals, _, labels, regression_targets = roi_heads.select_training_samples(
            train_proposals, targets
        )
        class_logits, box_regression = _box_predictions(
            roi_heads, features, sampled_proposals, image_shapes
        )
        loss_classifier, loss_box_reg = fastrcnn_loss(
            class_logits, box_regression, labels, regression_targets
        )
        detector_losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}

    class_logits, box_regression = _box_predictions(roi_heads, features, eval_proposals, image_shapes)
    boxes, scores, labels = roi_heads.postprocess_detections(
        class_logits, box_regression, eval_proposals, image_shapes
    )
    detections: List[Dict[str, Tensor]] = [
        {"boxes": box, "labels": label, "scores": score}
        for box, label, score in zip(boxes, labels, scores)
    ]
    return detector_losses, detections
And then testing it out:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load a model pre-trained on COCO.
# (On torchvision >= 0.13 the non-deprecated spelling is
#  fasterrcnn_resnet50_fpn(weights="DEFAULT").)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Replace the classifier with a new one that has a user-defined num_classes.
num_classes = 2  # 1 class (person) + background
# Get the number of input features for the classifier.
in_features = model.roi_heads.box_predictor.cls_score.in_features
# Replace the pre-trained head with a new one.
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# One 3x300x300 image with values in [0, 1], passed as a list of tensors, plus a
# single ground-truth box. Label 0 is background, so "person" is label 1; with
# label 0 the box would count as background and loss_box_reg would always be 0.
images = [torch.rand(3, 300, 300)]
targets = [
    {
        "boxes": torch.tensor([[100.0, 100.0, 200.0, 200.0]]),  # float (x1, y1, x2, y2)
        "labels": torch.tensor([1]),
    }
]

# Validation losses are only monitored, not back-propagated, so skip autograd.
with torch.no_grad():
    losses, detections = eval_forward(model, images, targets)
print(losses)
{'loss_classifier': tensor(0.6594, grad_fn=<NllLossBackward0>),
'loss_box_reg': tensor(0., grad_fn=<DivBackward0>),
'loss_objectness': tensor(0.5108, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
'loss_rpn_box_reg': tensor(0.0160, grad_fn=<DivBackward0>)}