How to format forward function for multiple inputs

I’m trying to get a U-Net model to take multiple inputs (8 separate audio spectrograms, each of torch.Size([1, 1024, 160])) and produce a single output (a stereo audio mixture of the 8 tracks, of torch.Size([2, 1024, 160])). I’m unsure how to write the network’s forward function for this. My DataLoader appears to be implemented correctly (with batch_size = 1):

input1, input2, input3, input4, input5, input6, input7, input8, target = next(iter(train_loader))

input1.shape
>>> torch.Size([1, 1, 1024, 160])

target.shape
>>> torch.Size([1, 2, 1024, 160])

My model architecture is:

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            *self.base_layers[1:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(1, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
        x_original = self.conv_original_size0(x1, x2, x3, x4, x5, x6, x7, x8)
        x_original = self.conv_original_size1(x_original)

        #layer0 = self.layer0(input)
        layer0 = self.layer0(x1, x2, x3, x4, x5, x6, x7, x8)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

My training loop is:

def train_single_epoch(model, data_loader, loss_fn, optimiser, device):
    for input1, input2, input3, input4, input5, input6, input7, input8, target in data_loader:
        input1, input2, input3, input4, input5, input6, input7, input8, target = input1.to(device), input2.to(device), input3.to(device), input4.to(device), input5.to(device), input6.to(device), input7.to(device), input8.to(device), target.to(device)

        prediction = model(input1, input2, input3, input4, input5, input6, input7, input8)
        loss = loss_fn(prediction, target)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    print(f"Training loss: {loss.item()}")

def train(model, data_loader, loss_fn, optimiser, device, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        train_single_epoch(model, data_loader, loss_fn, optimiser, device)
        for vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget in data_loader:
            vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget = vinput1.to(device), vinput2.to(device), vinput3.to(device), vinput4.to(device), vinput5.to(device), vinput6.to(device), vinput7.to(device), vinput8.to(device), vtarget.to(device)

            vprediction = model(vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8)
            vloss = loss_fn(vprediction, vtarget)

        print(f"Validation loss: {vloss.item()}")
        print("---------------------------")
    print("Finished training")

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

net = ResNetUNet(n_class=2).to(device)

loss_fn = nn.L1Loss()
optimiser = torch.optim.Adam(net.parameters(),
                             lr=0.001)

train(net, train_loader, loss_fn, optimiser, device, 10)

When I go to train the model I get the following error: TypeError: forward() takes 2 positional arguments but 9 were given.

Full traceback:

TypeError                                 Traceback (most recent call last)
Input In [52], in <cell line: 17>()
     13 optimiser = torch.optim.Adam(net.parameters(),
     14                              lr=0.001)
     16 # train model
---> 17 train(net, train_loader, loss_fn, optimiser, device, 10)

Input In [51], in train(model, data_loader, loss_fn, optimiser, device, epochs)
     18 for i in range(epochs):
     19     print(f"Epoch {i+1}")
---> 20     train_single_epoch(model, data_loader, loss_fn, optimiser, device)
     21     for vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget in data_loader:
     22         vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget = vinput1.to(device), vinput2.to(device), vinput3.to(device), vinput4.to(device), vinput5.to(device), vinput6.to(device), vinput7.to(device), vinput8.to(device), vtarget.to(device)

Input In [51], in train_single_epoch(model, data_loader, loss_fn, optimiser, device)
      3 input1, input2, input3, input4, input5, input6, input7, input8, target = input1.to(device), input2.to(device), input3.to(device), input4.to(device), input5.to(device), input6.to(device), input7.to(device), input8.to(device), target.to(device)
      5 # calculate loss
----> 6 prediction = model(input1, input2, input3, input4, input5, input6, input7, input8)
      7 loss = loss_fn(prediction, target)
      9 # backpropagate error and update weights

File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

Input In [50], in ResNetUNet.forward(self, x1, x2, x3, x4, x5, x6, x7, x8)
     41 def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
---> 42     x_original = self.conv_original_size0(x1, x2, x3, x4, x5, x6, x7, x8)
     43     x_original = self.conv_original_size1(x_original)
     45     #layer0 = self.layer0(input)

File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() takes 2 positional arguments but 9 were given

I’m unsure how to properly feed multiple inputs to my network, so any help is appreciated. I’m still fairly new to PyTorch, so any other pointers on errors in my net (or better ways to implement things) are welcome.

There have been many similar questions like this. The error isn’t really about your model’s forward taking multiple arguments (that is allowed); it’s that you then pass all eight tensors to submodules like conv_original_size0 and layer0, whose forward takes a single input tensor.

If you really need one net, then use one net, but make the input the concatenation of input1, input2, etc., and handle that concatenation in forward.
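
A minimal sketch of that idea, assuming the eight (N, 1, H, W) spectrograms are joined along the channel dimension (MultiInputNet is a made-up toy model, not your full U-Net):

import torch
import torch.nn as nn

class MultiInputNet(nn.Module):
    def __init__(self, n_class=2):
        super().__init__()
        self.features = nn.Sequential(
            # 8 input channels, one per track, instead of 1
            nn.Conv2d(8, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.head = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, *inputs):
        # each input: (N, 1, H, W); after cat: (N, 8, H, W)
        x = torch.cat(inputs, dim=1)
        return self.head(self.features(x))

# quick shape check
net = MultiInputNet()
xs = [torch.randn(1, 1, 1024, 160) for _ in range(8)]
print(net(*xs).shape)  # torch.Size([1, 2, 1024, 160])

In your ResNetUNet the analogous change would be widening the first convolutions (the Conv2d inside layer0 and conv_original_size0) from 1 to 8 input channels and doing the torch.cat at the top of forward.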

Thanks for the reply @blackbirdbarber. I have seen an answer to a similar question where *inputs is used, for instance:

*inputs, target = next(iter(train_loader))

which gives me a list of the input tensors. Could I then use this in the forward function and concatenate before the first layer? Also, is torch.cat or torch.stack better in this case?

In order to use that line you need train_loader to be a DataLoader. There are just two classes in PyTorch’s data loading pipeline: Dataset and DataLoader. DataLoader’s first parameter is the Dataset.
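
For reference, a minimal sketch of the two working together (SpectrogramDataset is a hypothetical name, assuming all tensors are already in memory):

import torch
from torch.utils.data import Dataset, DataLoader

class SpectrogramDataset(Dataset):
    def __init__(self, inputs, target):
        self.inputs = inputs  # list of 8 tensors, each (num_samples, 1, 1024, 160)
        self.target = target  # tensor of shape (num_samples, 2, 1024, 160)

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        # returns the 8 input spectrograms plus the target for one sample
        return (*(t[idx] for t in self.inputs), self.target[idx])

# DataLoader wraps the Dataset and adds batching and shuffling:
# loader = DataLoader(SpectrogramDataset(inputs, target), batch_size=1, shuffle=True)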

But first, are you doing SSL or supervised learning?

Yes, the line works; inspecting the shape of the elements in inputs after that line gives the same output as in my question:

inputs[7].shape
>>> torch.Size([1, 1, 1024, 160])

train_loader is a DataLoader:

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

where

train_dataset = TensorDataset(input1_specgrams, input2_specgrams, input3_specgrams, input4_specgrams, input5_specgrams, input6_specgrams, input7_specgrams, input8_specgrams, target_specgrams)

This is a supervised learning task. I’m unfamiliar with SSL (semi-supervised learning?), but all of my data has true targets.

torch.stack concatenates a sequence of tensors along a new dimension, which you rarely want here; torch.cat concatenates along an existing dimension, so it’s the right choice for joining your spectrograms along the channel dimension.
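
To make the difference concrete, a quick shape check with two dummy tensors of your spectrogram size:

import torch

a = torch.randn(1, 1, 1024, 160)
b = torch.randn(1, 1, 1024, 160)

# cat joins along an existing dimension: channels grow from 1 to 2
print(torch.cat([a, b], dim=1).shape)    # torch.Size([1, 2, 1024, 160])

# stack inserts a brand-new dimension at the chosen index
print(torch.stack([a, b], dim=1).shape)  # torch.Size([1, 2, 1, 1024, 160])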