I’m trying to get a U-Net model to take multiple inputs (8 separate audio spectrograms, each of torch.Size([1, 1024, 160])) and produce a single output (a stereo mix of the 8 tracks, of torch.Size([2, 1024, 160])). I’m unsure how to write the forward function of the net for this purpose. My DataLoader appears to be implemented correctly (with batch_size = 1):
input1, input2, input3, input4, input5, input6, input7, input8, target = next(iter(train_loader))
input1.shape
>>> torch.Size([1, 1, 1024, 160])
target.shape
>>> torch.Size([1, 2, 1024, 160])
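Since each loader item already has a singleton channel dimension, my working assumption is that the eight spectrograms could be merged into a single 8-channel tensor, e.g. with torch.cat along dim=1 (this snippet is just my illustration of the idea, not something my code currently does):
import torch

# Hypothetical illustration (not in my code): merge eight
# (1, 1, 1024, 160) spectrograms into one (1, 8, 1024, 160) tensor
# along the channel dimension.
inputs = [input1, input2, input3, input4, input5, input6, input7, input8]
x = torch.cat(inputs, dim=1)
print(x.shape)  # torch.Size([1, 8, 1024, 160])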
My model architecture is:
import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )
class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())
        # replace ResNet's first (3-channel) conv with a 1-channel one
        self.layer0 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            *self.base_layers[1:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
        self.conv_original_size0 = convrelu(1, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
        x_original = self.conv_original_size0(x1, x2, x3, x4, x5, x6, x7, x8)
        x_original = self.conv_original_size1(x_original)
        #layer0 = self.layer0(input)
        layer0 = self.layer0(x1, x2, x3, x4, x5, x6, x7, x8)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)
        out = self.conv_last(x)
        return out
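As far as I understand, nn.Sequential and the nn.Conv2d inside convrelu only accept a single input tensor in their forward, so I suspect calls like self.conv_original_size0(x1, ..., x8) can't work as written. A minimal check (illustration only, not part of my model code) reproduces the same kind of error:
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 64, 3, padding=1)
x = torch.randn(1, 1, 1024, 160)
y = conv(x)   # fine: a single 4D tensor in, a single tensor out
# conv(x, x)  # TypeError: forward() takes 2 positional arguments but 3 were given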
My training loop is:
def train_single_epoch(model, data_loader, loss_fn, optimiser, device):
    for input1, input2, input3, input4, input5, input6, input7, input8, target in data_loader:
        input1, input2, input3, input4, input5, input6, input7, input8, target = input1.to(device), input2.to(device), input3.to(device), input4.to(device), input5.to(device), input6.to(device), input7.to(device), input8.to(device), target.to(device)

        # calculate loss
        prediction = model(input1, input2, input3, input4, input5, input6, input7, input8)
        loss = loss_fn(prediction, target)

        # backpropagate error and update weights
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    print(f"Training loss: {loss.item()}")

def train(model, data_loader, loss_fn, optimiser, device, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        train_single_epoch(model, data_loader, loss_fn, optimiser, device)
        for vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget in data_loader:
            vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget = vinput1.to(device), vinput2.to(device), vinput3.to(device), vinput4.to(device), vinput5.to(device), vinput6.to(device), vinput7.to(device), vinput8.to(device), vtarget.to(device)
            vprediction = model(vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8)
            vloss = loss_fn(vprediction, vtarget)
        print(f"Validation loss: {vloss.item()}")
        print("---------------------------")
    print("Finished training")
I then set up the device, model, loss function, and optimiser, and kick off training:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

net = ResNetUNet(n_class=2).to(device)
loss_fn = nn.L1Loss()
optimiser = torch.optim.Adam(net.parameters(), lr=0.001)

# train model
train(net, train_loader, loss_fn, optimiser, device, 10)
When I go to train the model, I get the following error: TypeError: forward() takes 2 positional arguments but 9 were given.
Full traceback:
TypeError Traceback (most recent call last)
Input In [52], in <cell line: 17>()
13 optimiser = torch.optim.Adam(net.parameters(),
14 lr=0.001)
16 # train model
---> 17 train(net, train_loader, loss_fn, optimiser, device, 10)
Input In [51], in train(model, data_loader, loss_fn, optimiser, device, epochs)
18 for i in range(epochs):
19 print(f"Epoch {i+1}")
---> 20 train_single_epoch(model, data_loader, loss_fn, optimiser, device)
21 for vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget in data_loader:
22 vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8, vtarget = vinput1.to(device), vinput2.to(device), vnput3.to(device), vinput4.to(device), vinput5.to(device), vinput6.to(device), vinput7.to(device), vinput8.to(device), vtarget.to(device)
Input In [51], in train_single_epoch(model, data_loader, loss_fn, optimiser, device)
3 input1, input2, input3, input4, input5, input6, input7, input8, target = input1.to(device), input2.to(device), input3.to(device), input4.to(device), input5.to(device), input6.to(device), input7.to(device), input8.to(device), target.to(device)
5 # calculate loss
----> 6 prediction = model(input1, input2, input3, input4, input5, input6, input7, input8)
7 loss = loss_fn(prediction, target)
9 # backpropagate error and update weights
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
Input In [50], in ResNetUNet.forward(self, x1, x2, x3, x4, x5, x6, x7, x8)
41 def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
---> 42 x_original = self.conv_original_size0(x1, x2, x3, x4, x5, x6, x7, x8)
43 x_original = self.conv_original_size1(x_original)
45 #layer0 = self.layer0(input)
File ~\anaconda3\envs\TorchCuda\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() takes 2 positional arguments but 9 were given
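My best guess is that I need to merge the eight spectrograms into a single tensor at the start of forward and widen the first convolutions to accept 8 input channels, along the lines of the sketch below. The torch.cat and the in_channels=8 changes are my own assumptions and untested:
# Untested sketch: in __init__ I would presumably also need
#   self.conv_original_size0 = convrelu(8, 64, 3, 1)
#   and the first conv of layer0 changed to nn.Conv2d(8, 64, ...)
def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
    # merge the eight (1, 1, 1024, 160) inputs into one (1, 8, 1024, 160) tensor
    x = torch.cat([x1, x2, x3, x4, x5, x6, x7, x8], dim=1)
    x_original = self.conv_original_size0(x)
    x_original = self.conv_original_size1(x_original)
    layer0 = self.layer0(x)
    # ...the rest of forward would stay as in my code above...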
I’m unsure how to properly introduce multiple inputs to my network, so any help is appreciated. I’m still fairly new to PyTorch, so any other pointers on errors in my net (or better ways to implement things) are welcome.
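One more thing I’m unsure about: whether the validation pass in train should run under model.eval() and torch.no_grad() so it doesn’t track gradients. My guess (unverified) is something like:
model.eval()
with torch.no_grad():
    vprediction = model(vinput1, vinput2, vinput3, vinput4, vinput5, vinput6, vinput7, vinput8)
    vloss = loss_fn(vprediction, vtarget)
model.train()  # switch back before the next training epoch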