How to split input channels at start of model?

The start of my model looks like this:

def __init__(self, pretrained=False):
  super().__init__()
  self.base = nn.ModuleList([])

  # Stem layer
  self.base.append(ConvLayer(in_channels=4, out_channels=first_ch[0], kernel=3, stride=2))
  self.base.append(ConvLayer(in_channels=first_ch[0], out_channels=first_ch[1], kernel=3))
  self.base.append(nn.MaxPool2d(kernel_size=2, stride=2))

  # Rest of model definition goes here
  self.base.append(...)

def forward(self, x):
  out_branch = []
  for i in range(len(self.base)-1):
    x = self.base[i](x)
    out_branch.append(x)
  return out_branch

Notice that the first conv layer takes 4 channels as input. The input is an RGB tensor (first 3 channels) with an extra channel added by the data loader. I want to change this so that the input is first split into 3 channels and 1 channel, giving two stems: one stem for the RGB channels and a second stem for the 4th channel. I would then concatenate the output of the single-channel stem back in at some later point in the network. What might be the best way to do this?

You could slice the input in your forward method and pass the chunks to each corresponding stem. Afterwards, use torch.cat to concatenate the activations again.
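Something along these lines, for example (a minimal, self-contained sketch; TwoStemNet and the plain nn.Conv2d stems are just placeholders for your ConvLayer-based stems and channel counts):

import torch
import torch.nn as nn

class TwoStemNet(nn.Module):
  def __init__(self):
    super().__init__()
    # placeholder stems; swap in your own ConvLayer blocks here
    self.rgb_stem = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
    self.extra_stem = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)

  def forward(self, x):
    rgb = x[:, :3]    # first three channels (RGB)
    extra = x[:, 3:]  # remaining channel
    rgb = self.rgb_stem(rgb)
    extra = self.extra_stem(extra)
    # concatenate the activations back together along the channel dimension
    return torch.cat((rgb, extra), dim=1)

out = TwoStemNet()(torch.randn(2, 4, 64, 64))
print(out.shape)  # torch.Size([2, 32, 32, 32])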


Thanks for the pointers. I think I made some progress with it, but have hit an issue. Here is my new code:

def __init__(self, pretrained=False):
  super().__init__()
  self.base = nn.ModuleList([])

  # RGB stem layer
  self.rgb_stem_a = ConvLayer(in_channels=3, out_channels=first_ch[0], kernel=3, stride=2, bias=False, ibn=True)
  self.rgb_stem_b = ConvLayer(in_channels=first_ch[0], out_channels=first_ch[1], kernel=3, ibn=False, sn=True)

  # Single-channel stem layer
  self.single_stem_a = ConvLayer(in_channels=1, out_channels=first_ch[0], kernel=3, stride=2, bias=False, ibn=True)
  self.single_stem_b = ConvLayer(in_channels=first_ch[0], out_channels=first_ch[1], kernel=3, ibn=False, sn=True)

  self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

  # Rest of model definition goes here
  self.base.append(...)

def forward(self, x):
  out_branch = []
  rgb_channels = x[:, [0, 1, 2]]  # first three channels (RGB)
  single_channel = x[:, [3]]      # fourth channel

  rgb = self.rgb_stem_a(rgb_channels)
  rgb = self.rgb_stem_b(rgb)

  single = self.single_stem_a(single_channel)
  single = self.single_stem_b(single)

  r = self.maxpool(rgb)
  g = self.maxpool(single)

  cat = torch.cat((r, g), dim=0)

  for i in range(len(self.base)-1):
    if i == 0:
      x = self.base[i](cat)
    else:
      x = self.base[i](x)
    out_branch.append(x)
  return out_branch

However, with this new code I get the following error:

File "/miniconda3/envs/test/lib/python3.9/site-packages/torch/nn/functional.py", 
line 3163, in binary_cross_entropy_with_logits
raise ValueError("Target size ({}) must be the same as input size 
({})".format(target.size(), input.size()))
ValueError: Target size (torch.Size([5, 1, 512, 512])) must be the same as 
input size (torch.Size([10, 1, 512, 512]))

I’m guessing this may be related to the way I am concatenating the two maxpool outputs. Can you spot anything obvious that is wrong with my approach? Here’s the output of cat.shape:

torch.Size([10, 60, 128, 128])
torch.Size([10, 150, 128, 128])
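For reference, here is a quick check of how the dim argument of torch.cat changes the output shape, assuming r and g each come out as [5, 60, 128, 128] in my case:

import torch

r = torch.randn(5, 60, 128, 128)
g = torch.randn(5, 60, 128, 128)

print(torch.cat((r, g), dim=0).shape)  # torch.Size([10, 60, 128, 128]) - stacks along the batch dimension
print(torch.cat((r, g), dim=1).shape)  # torch.Size([5, 120, 128, 128]) - stacks along the channel dimension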