Hi,
I subclassed torchvision.models.detection.rpn.RegionProposalNetwork to add some custom features to my Faster R-CNN (also based on torchvision). When I change the anchor sizes of the anchor generator, I get this error:
Traceback (most recent call last):
  File "/home/shiqian/zhangzh/FAFRCNN/train.py", line 17, in <module>
    train()
  File "/home/shiqian/zhangzh/FAFRCNN/train.py", line 10, in train
    best_path = source_train()
  File "/home/shiqian/zhangzh/FAFRCNN/source_train.py", line 78, in source_train
    losses, detections, image_sizes, original_image_sizes = faster_rcnn_mp(images, targets)
  File "/home/shiqian/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shiqian/zhangzh/FAFRCNN/model/faster_rcnn/faster_rcnn_model_parallel.py", line 213, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/shiqian/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shiqian/zhangzh/FAFRCNN/model/faster_rcnn/faster_rcnn_model_parallel.py", line 382, in forward
    proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) # Tensor at GPU4
  File "/home/shiqian/.local/lib/python3.5/site-packages/torchvision/models/detection/_utils.py", line 168, in decode
    rel_codes.reshape(sum(boxes_per_image), -1), concat_boxes
RuntimeError: shape '[1670004, -1]' is invalid for input of size 6779988
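For reference, here is a quick arithmetic check of the two numbers in the error message. It only uses the values shown in the traceback; the variable names are made up for illustration, and the interpretation (4 deltas per box, `sum(boxes_per_image)` being the anchor count) is read off the `decode(pred_bbox_deltas, anchors)` call, so treat it as a sketch rather than a diagnosis.

```python
# Numbers copied from the RuntimeError above.
num_anchors = 1670004       # sum(boxes_per_image): rows the reshape asks for
num_delta_values = 6779988  # total elements in pred_bbox_deltas (rel_codes)

# Each box is encoded by 4 deltas, so this is how many boxes were predicted.
num_predicted_boxes, remainder = divmod(num_delta_values, 4)

# reshape(num_anchors, -1) only works if the element count is a multiple
# of num_anchors.
divides_evenly = num_delta_values % num_anchors == 0

# How many more boxes were predicted than there are anchors.
extra_boxes = num_predicted_boxes - num_anchors

print(num_predicted_boxes, remainder, divides_evenly, extra_boxes)
```

So the deltas describe more boxes than there are anchors, which suggests the RPN head and the anchor generator disagree about how many locations they cover.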
The only thing I changed was the default anchor setup in the Faster R-CNN initialization, from
if rpn_anchor_generator is None:
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = AnchorGenerator(
        anchor_sizes, aspect_ratios
    )
to
if rpn_anchor_generator is None:
    anchor_sizes = ((128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = AnchorGenerator(
        anchor_sizes, aspect_ratios
    )
because my GPU can't handle that many anchors.
Could anybody help me out?
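To illustrate what might be going on, here is a minimal pure-Python sketch of the counting involved. It assumes (not confirmed for this model) that the backbone still returns five feature maps, that the anchor generator pairs each size tuple with one feature map in order, and that the RPN head predicts on every feature map it receives. The feature-map sizes and the helper names `count_anchors` and `count_predictions` are hypothetical.

```python
# Hypothetical (height, width) of five pyramid levels; real sizes depend on
# the input image and the backbone.
feature_map_sizes = [(200, 272), (100, 136), (50, 68), (25, 34), (13, 17)]
num_ratios = 3  # aspect ratios (0.5, 1.0, 2.0)

five_levels = ((32,), (64,), (128,), (256,), (512,))
three_levels = ((128,), (256,), (512,))


def count_anchors(anchor_sizes, fmap_sizes, ratios=num_ratios):
    # Assumption: sizes and feature maps are paired one-to-one via zip,
    # so feature maps without a matching size tuple get no anchors.
    return sum(h * w * len(sizes) * ratios
               for sizes, (h, w) in zip(anchor_sizes, fmap_sizes))


def count_predictions(fmap_sizes, anchors_per_location):
    # Assumption: the head predicts at every location of every feature map.
    return sum(h * w * anchors_per_location for h, w in fmap_sizes)


predicted = count_predictions(feature_map_sizes, num_ratios)
print(count_anchors(five_levels, feature_map_sizes), predicted)
print(count_anchors(three_levels, feature_map_sizes), predicted)
```

Under these assumptions the counts agree with five size tuples but not with three, which is the same kind of mismatch as in the error above. If that is what is happening here, the number of size tuples and the number of feature maps passed to the RPN would need to agree, one way or the other.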