Extract features from Mask R-CNN

I am following [1] to extract features from different layers. The Mask R-CNN model uses a ResNet-50 backbone, and that is where I want to attach the feature extractor:

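    from torchvision.models.feature_extraction import (
        create_feature_extractor,
        get_graph_node_names,
    )

    # `model` is a torchvision Mask R-CNN with a ResNet-50 backbone and
    # `img` an input image tensor; both are created earlier in the script.
    # List the node names that can be passed as return_nodes.
    train_nodes, eval_nodes = get_graph_node_names(model.backbone)

    nodes = {
        "body.conv1": "layer1",
        "body.maxpool": "layer2",
    }
    # Replace the backbone with an extractor that returns these two nodes
    backbone = create_feature_extractor(model.backbone, return_nodes=nodes)
    model.backbone = backbone
    model.eval()
    out = model(img)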

When executing this, the following error is thrown:

      File "C:\project/ml/visualize_model.py", line 184, in <module>
        out = model(img)
      File "C:\project\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\project\venv\lib\site-packages\torchvision\models\detection\generalized_rcnn.py", line 98, in forward
        proposals, proposal_losses = self.rpn(images, features, targets)
      File "C:\project\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\project\venv\lib\site-packages\torchvision\models\detection\rpn.py", line 341, in forward
        objectness, pred_bbox_deltas = self.head(features)
      File "C:\project\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\project\venv\lib\site-packages\torchvision\models\detection\rpn.py", line 50, in forward
        t = F.relu(self.conv(feature))
      File "C:\project\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\project\venv\lib\site-packages\torch\nn\modules\conv.py", line 447, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "C:\project\venv\lib\site-packages\torch\nn\modules\conv.py", line 443, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Given groups=1, weight of size [256, 256, 3, 3], expected input[1, 64, 400, 512] to have 256 channels, but got 64 channels instead

[1] "Feature extraction for model inspection", Torchvision main documentation