Train DeepLabV3 on 4-channel images

Hi there, I want to train DeepLabV3 on my own dataset of 4-channel images, but I didn't find any PyTorch implementation of DeepLabV3 where I could change the parameters and the number of input channels to fit my 4-channel images.
How can I modify DeepLabV3 to adapt it to my dataset?

torchvision provides DeepLabV3 implementations in torchvision.models.segmentation, and you can replace the first conv layer of the backbone so that it accepts 4 input channels, as seen here:

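import torch
import torch.nn as nn
from torchvision import models

model = models.segmentation.deeplabv3_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None)

# The default model expects 3-channel input
x = torch.randn(2, 3, 224, 224)
out = model(x)  # dict; out['out'] has shape [2, 21, 224, 224]

# Swap the ResNet stem for a 4-channel conv (same kernel size, stride, padding; randomly initialized)
model.backbone.conv1 = nn.Conv2d(4, 64, 7, 2, 3, bias=False)
x = torch.randn(2, 4, 224, 224)
out = model(x)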