How to use 4-channel images with Faster R-CNN in PyTorch

I want to train Faster R-CNN on 4-channel images. I use ResNet-50 as the backbone and have changed it to accept 4-channel input. What other changes do I have to make to train the whole Faster R-CNN?

You can replace the first conv layer, then keep training as usual.

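import torch.nn as nn
import torchvision.models as models

model = models.resnet50(pretrained=True)
n_image_channels = 4
# Same settings as the original ResNet-50 stem; only the input channel count changes.
# The new layer is randomly initialized, so the pretrained conv1 weights are not reused.
model.conv1 = nn.Conv2d(n_image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)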
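
Replacing conv1 outright throws away its pretrained filters. One common workaround (not something proposed in this thread, just a sketch) is to copy the existing RGB filters into the new layer and start the extra channel at their mean. The helper name `expand_input_channels` is made up for illustration, and it assumes the new channel count is at least the old one.

```python
import torch
import torch.nn as nn

def expand_input_channels(conv, n_channels):
    """Return a new Conv2d that accepts `n_channels` inputs.

    The filters for the original input channels are copied over; every extra
    channel starts at the mean of the original filters (an assumed heuristic).
    """
    new_conv = nn.Conv2d(
        n_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        k = conv.in_channels
        new_conv.weight[:, :k] = conv.weight
        new_conv.weight[:, k:] = conv.weight.mean(dim=1, keepdim=True)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# Stand-in for a ResNet-50 stem; with a real model this would be model.conv1.
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_input_channels(old_conv, 4)
```

With a torchvision ResNet this would be used as `model.conv1 = expand_input_channels(model.conv1, 4)`.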

You can add a 1x1 convolution layer as the first layer in your network with a 4 channel input and 3 channel output.

I made these changes for 4 channels, but it doesn’t work. Most likely I’m making a mistake somewhere, but I can’t figure out where.

import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet50
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

def get_resnet50(n_classes, n_channels=4):
    model = resnet50(pretrained=False)
    for p in model.parameters():
        p.requires_grad = False
    inft = model.fc.in_features
    model.fc = nn.Linear(in_features=inft, out_features=n_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    return model

def get_faster_rcnn(n_classes, n_channels=4):
    resnet = get_resnet50(n_classes, n_channels)
    resnet_features = list(resnet.children())[:-1]
    backbone = nn.Sequential(*resnet_features)

    # ResNet50 output channels is 2048
    backbone.out_channels = 2048

    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))

    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0], output_size=7, sampling_ratio=2)

    model = FasterRCNN(backbone=backbone, num_classes=n_classes, rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler)

    return model

I don’t know, haven’t tried that out yet. Would help if you’d share the errors you’re getting.

A backbone for an object detector normally doesn’t include a fully connected layer (nor, I think, global average pooling), but I don’t know whether torchvision hooks into this and fixes it.

I’m running into the same problem, were you able to solve it?

EDIT: never mind, my problem had nothing to do with it. (I don’t want to leave someone in the future hanging: my problem was that I divided the input image by 255 one too many times, so the pixel values were between 0 and ≈0.004.)
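
A small guard against that kind of double scaling, as a sketch: divide by 255 only when the image is still `uint8`, so calling it a second time changes nothing. The helper name `to_unit_range` is made up, and it assumes float images are already in [0, 1].

```python
import torch

def to_unit_range(image):
    """Scale an image to [0, 1] exactly once.

    uint8 images are divided by 255; float images are assumed to be
    scaled already and are returned unchanged (as float32).
    """
    if image.dtype == torch.uint8:
        return image.float() / 255.0
    return image.float()

raw = torch.tensor([[0, 128, 255]], dtype=torch.uint8)
once = to_unit_range(raw)
twice = to_unit_range(once)  # second call is a no-op
```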