How to calculate validation loss for Faster R-CNN?

The loss can be obtained while the model is in model.train() mode, but that is not the right way to compute a validation loss, since batch normalization, dropout, etc. are only deactivated in evaluation mode and stay active in training mode. In evaluation mode (model.eval()), however, the model does not return the loss at all.
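Roughly, the torchvision detection API behaves like this (a minimal sketch):

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
images = [torch.randn(3, 300, 300)]
targets = [{'boxes': torch.tensor([[50., 50., 150., 150.]]),
            'labels': torch.tensor([1])}]

model.train()
loss_dict = model(images, targets)  # dict of training losses, no detections

model.eval()
detections = model(images)  # list of per-image predictions, no losses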

It is almost the same question as Compute validation loss for Faster RCNN. Please help.

@ptrblck please help me overcome this issue.

Would it work if you call .eval() only on all dropout and batchnorm layers, while the parent module is kept in the training state?
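E.g. something like this (a minimal sketch, assuming model is the Faster R-CNN instance):

model.train()  # parent module stays in training mode
for module in model.modules():
    if isinstance(module, (torch.nn.Dropout, torch.nn.BatchNorm2d)):
        module.eval()  # only these layers use their eval behavior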

@ptrblck If I call model.eval(), I will only be able to predict the results, but I am unable to get the validation loss.
I can get it as:

with torch.no_grad():
    for image, target in val_loader:
        …

But I can't put the model into evaluation mode. Is there any way I can calculate the validation loss? (Because if I do it in model.train(), batch normalization and dropout will be active.)

@ptrblck I used a torchvision model (fasterrcnn). So do I need to edit the source code for that?

I’m not suggesting to call model.eval(), but .eval() only on dropout and batchnorm layers.
Why wouldn’t you be able to calculate the validation loss using this approach?


@ptrblck Thanks, that sounds great. Let me try it that way.

Is there a recent working example that shows how to calculate validation loss without updating batch norm and/or dropout layers?

I'm attempting this at the moment:

@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    model.train()
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        val_loss += losses_reduced

    validation_loss = val_loss / len(data_loader)
    return validation_loss

Where I then place the function in this loop:

for epoch in range(args.num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)

    # update the learning rate
    lr_scheduler.step()

    validation_loss = evaluate_loss(model, valid_data_loader, device=device)

    # evaluate on the test dataset
    evaluate(model, valid_data_loader, device=device)

However, I'm still not sure if this will cause issues. For example, I've been told and have read everywhere that staying in model.train() during evaluation would still update the batch norm running statistics and keep dropout active…
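For example (a minimal sketch with a standalone nn.BatchNorm2d), the running stats do change even under torch.no_grad():

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.running_mean)  # tensor([0., 0., 0.])

with torch.no_grad():
    bn.train()
    bn(torch.randn(4, 3, 8, 8) + 5.)

print(bn.running_mean)  # shifted towards ~0.5 after one batch, despite no_grad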

I only need to track the validation loss to save the “best” model, i.e. the one with the lowest validation loss out of 100 epochs.

This is correct, which is why I suggested calling eval() on the batchnorm and dropout layers only in this topic. Did you try out this approach, and was it not working?

Hi, I'm wondering where these amendments should be made if one is following the PyTorch tutorial on object detection?

Call .eval() on the mentioned layers by iterating over all modules before starting the validation loop.
E.g. something like this would work:

import torch.nn as nn
from torchvision import models

model = models.resnet50()

# check that all layers are in train mode
for name, module in model.named_modules():
    if hasattr(module, 'training'):
        print('{} is training {}'.format(name, module.training))

# set bn layers to eval
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()

# bn layers are now in eval
for name, module in model.named_modules():
    if hasattr(module, 'training'):
        print('{} is training {}'.format(name, module.training))

Hi,

Thanks for the reply.

So I amended my block of code to this:

@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    val_loss = 0
    model.train()

    # set bn layers to eval once, before the loop
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()

    # check which layers are still in train mode
    for name, module in model.named_modules():
        if hasattr(module, 'training'):
            print('{} is training {}'.format(name, module.training))

    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        val_loss += losses_reduced

    validation_loss = val_loss / len(data_loader)
    return validation_loss

This is an example snippet of the output:

backbone.body.layer1.1.conv1 is training True
backbone.body.layer1.1.bn1 is training True
backbone.body.layer1.1.conv2 is training True
backbone.body.layer1.1.bn2 is training True
backbone.body.layer1.1.conv3 is training True
backbone.body.layer1.1.bn3 is training True

Is this the expected behaviour I should be getting to ensure validation losses are correct? If so, should I also be adding a similar block for dropout?

No, it's not expected, since all layers are still in training mode.
Check my code snippet and compare what might be different in your model. I guess you might be using nn.BatchNorm1d or ...3d layers, as the if condition seems to fail.
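If that is the case, you could match the shared base class instead (a sketch; note that _BatchNorm is a private PyTorch internal):

from torch.nn.modules.batchnorm import _BatchNorm

# catches BatchNorm1d/2d/3d and SyncBatchNorm in one check
for module in model.modules():
    if isinstance(module, _BatchNorm):
        module.eval()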

I'm not really sure, to be honest; I just tried all the batch norm variants provided by the torch.nn module. They all seem to stay true for training…

I don't know if this makes a difference, but my model is originally set up like this:

# set up model

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, pretrained_backbone=True)
num_classes = 2  # 1 class (mitosis) + background

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
#lr_scheduler = None

EDIT

I just tried it with nn.Conv2d and the expected behaviour is there… just not with batch norm.

EDIT 2:

So I printed out the modules and it appears this ResNet uses FrozenBatchNorm2d. The problem I have now is that torchvision.ops.FrozenBatchNorm2d raises this error:

AttributeError: module 'torchvision.ops' has no attribute 'FrozenBatchNorm2d'

EDIT 3:

It's not shown in the documentation, but the correct module is torchvision.ops.misc.FrozenBatchNorm2d.

Since you are using FrozenBatchNorm2d, I don't think you would need to call eval() on it, as all parameters and buffers are already frozen. From the docs:

BatchNorm2d where the batch statistics and the affine parameters are fixed
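You can verify that there is nothing left to switch (a small sketch):

from torchvision.ops.misc import FrozenBatchNorm2d

fbn = FrozenBatchNorm2d(16)
print(list(fbn.parameters()))  # [] -- weight, bias, and the running stats are all buffers
print([name for name, _ in fbn.named_buffers()])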


I've been told to use the forward function from generalized_rcnn and iterate through my dataloader, switching everything to .eval() mode, to determine the losses. Would this be appropriate?

from typing import Tuple, List, Dict, Optional
import torch
from torch import Tensor
from collections import OrderedDict
from torchvision.models.detection.roi_heads import fastrcnn_loss
from torchvision.models.detection.rpn import concat_box_prediction_layers
def eval_forward(model, images, targets):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """
    Args:
        images (list[Tensor]): images to be processed
        targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            It returns list[BoxList] contains additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).
    """
    model.eval()

    original_image_sizes: List[Tuple[int, int]] = []
    for img in images:
        val = img.shape[-2:]
        assert len(val) == 2
        original_image_sizes.append((val[0], val[1]))

    images, targets = model.transform(images, targets)

    # Check for degenerate boxes
    if targets is not None:
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                raise ValueError(
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}."
                )

    features = model.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])
    model.rpn.training = True
    # model.roi_heads.training = True


    #####proposals, proposal_losses = model.rpn(images, features, targets)
    features_rpn = list(features.values())
    objectness, pred_bbox_deltas = model.rpn.head(features_rpn)
    anchors = model.rpn.anchor_generator(images, features_rpn)

    num_images = len(anchors)
    num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
    num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
    objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
    # apply pred_bbox_deltas to anchors to obtain the decoded proposals
    # note that we detach the deltas because Faster R-CNN does not backprop
    # through the proposals
    proposals = model.rpn.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    proposals = proposals.view(num_images, -1, 4)
    proposals, scores = model.rpn.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

    proposal_losses = {}
    assert targets is not None
    labels, matched_gt_boxes = model.rpn.assign_targets_to_anchors(anchors, targets)
    regression_targets = model.rpn.box_coder.encode(matched_gt_boxes, anchors)
    loss_objectness, loss_rpn_box_reg = model.rpn.compute_loss(
        objectness, pred_bbox_deltas, labels, regression_targets
    )
    proposal_losses = {
        "loss_objectness": loss_objectness,
        "loss_rpn_box_reg": loss_rpn_box_reg,
    }

    #####detections, detector_losses = model.roi_heads(features, proposals, images.image_sizes, targets)
    image_shapes = images.image_sizes
    proposals, matched_idxs, labels, regression_targets = model.roi_heads.select_training_samples(proposals, targets)
    box_features = model.roi_heads.box_roi_pool(features, proposals, image_shapes)
    box_features = model.roi_heads.box_head(box_features)
    class_logits, box_regression = model.roi_heads.box_predictor(box_features)

    result: List[Dict[str, torch.Tensor]] = []
    detector_losses = {}
    loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
    detector_losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
    boxes, scores, labels = model.roi_heads.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
    num_images = len(boxes)
    for i in range(num_images):
        result.append(
            {
                "boxes": boxes[i],
                "labels": labels[i],
                "scores": scores[i],
            }
        )
    detections = result
    detections = model.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
    model.rpn.training = False
    model.roi_heads.training = False
    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)
    return losses, detections

And then testing it out:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
losses, detections = eval_forward(model,torch.randn([1,3,300,300]),[{'boxes':torch.tensor([[100,100,200,200]]),'labels':torch.tensor([0])}])

{'loss_classifier': tensor(0.6594, grad_fn=<NllLossBackward0>),
'loss_box_reg': tensor(0., grad_fn=<DivBackward0>),
 'loss_objectness': tensor(0.5108, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 'loss_rpn_box_reg': tensor(0.0160, grad_fn=<DivBackward0>)}

Using the above code, I evaluate the validation loss by looking at loss_classifier:

def evaluate_loss(model, data_loader, device):
    val_loss = 0
    with torch.no_grad():
        for images, targets in data_loader:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            losses, detections = eval_forward(model, images, targets)

            val_loss += losses['loss_classifier']

    validation_loss = val_loss / len(data_loader)
    return validation_loss
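I can then track the best checkpoint across epochs, e.g. (a sketch based on the training loop above; best_model.pth is a hypothetical filename):

best_val_loss = float('inf')
for epoch in range(args.num_epochs):
    train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()

    validation_loss = evaluate_loss(model, valid_data_loader, device)
    if validation_loss < best_val_loss:
        best_val_loss = validation_loss
        torch.save(model.state_dict(), 'best_model.pth')  # hypothetical path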