I want to use 4-channel images to train a Faster R-CNN. I use ResNet-50 as the backbone and changed it to accept 4-channel images. What other changes do I have to make to train the whole Faster R-CNN?
You can replace the first conv layer and then train as usual:
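```python
import torch.nn as nn
import torchvision.models as models

# Newer torchvision spells this weights=models.ResNet50_Weights.DEFAULT
model = models.resnet50(pretrained=True)
n_image_channels = 4
# New first conv that accepts 4 input channels (this layer starts from random weights)
model.conv1 = nn.Conv2d(n_image_channels, model.conv1.out_channels,
                        kernel_size=7, stride=2, padding=3, bias=False)
```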
You can add a 1x1 convolution layer as the first layer in your network, with a 4-channel input and a 3-channel output.
I made these changes for 4 channels, but it doesn’t work. Most likely I’m wrong somewhere, but I can’t figure out where:
```python
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet50
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator


def get_resnet50(n_classes, n_channels=4):
    model = resnet50(pretrained=False)
    for p in model.parameters():
        p.requires_grad = False
    inft = model.fc.in_features
    model.fc = nn.Linear(in_features=inft, out_features=n_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    return model


def get_faster_rcnn(n_classes, n_channels=4):
    resnet = get_resnet50(n_classes, n_channels)
    resnet_features = list(resnet.children())[:-1]
    backbone = nn.Sequential(*resnet_features)
    # ResNet50 output channels is 2048
    backbone.out_channels = 2048
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                    output_size=7,
                                                    sampling_ratio=2)
    model = FasterRCNN(backbone=backbone, num_classes=n_classes,
                       rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler)
    return model
```
I don’t know, haven’t tried that out yet. Would help if you’d share the errors you’re getting.
A backbone for an object detector normally doesn’t use a fully connected layer (nor global average pooling?), since the detection heads need a spatial feature map, but I don’t know if torchvision hooks into this and fixes it.
I’m running into the same problem, were you able to solve it?
EDIT: never mind, my problem had nothing to do with this. (I don’t want to leave someone in the future hanging: my problem was that I divided the input image by 255 one too many times, so the pixel values were between 0 and ≈0.004.)
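One way to avoid that double scaling is to convert to floats in exactly one place and add a cheap sanity check. A minimal sketch, assuming images arrive as H×W×C NumPy arrays (e.g. RGB plus a fourth band); `to_float_image` is a hypothetical helper, and the 2/255 threshold is an arbitrary heuristic:

```python
import warnings

import numpy as np
import torch


def to_float_image(image: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: H x W x C array -> C x H x W float tensor in [0, 1]."""
    tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)
    if tensor.dtype == torch.uint8:
        # The single place where 0..255 integers become 0..1 floats.
        tensor = tensor.float() / 255.0
    else:
        # Float input is assumed to be scaled already, so it is not divided again.
        tensor = tensor.float()
    # Heuristic: after a second division by 255 every value is at most 1/255 (~0.004).
    if tensor.numel() > 0 and 0 < tensor.max().item() < 2.0 / 255.0:
        warnings.warn("all values are below 2/255; was the image divided by 255 twice?")
    return tensor


rgbn = np.random.randint(0, 256, size=(64, 64, 4), dtype=np.uint8)
x = to_float_image(rgbn)
```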