Forward() takes 1 positional argument but 2 were given

I’m a pytorch beginner, i try to write a unet, this is my code, when i use pytorch summary to summary my model output, i got this error: TypeError: forward() takes 1 positional argument but 2 were given

i don’t know why, who can help me?

class DownSample(nn.Module):

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int):
        super(DownSample, self).__init__()

        self.down = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=1),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU()
        )

        init_weight.initialize(self)

    def forward(self, x):
        return self.down(x)


class UpSample(nn.Module):

    def __init__(self, in_planes: int, out_planes: int,
                 kernel_size: int, padding: int, output_padding: int,
                 apply_dropout: bool = False):
        super(UpSample, self).__init__()

        self.up = nn.ModuleList()

        self.up.append(
            nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride=2,
                               padding=padding, output_padding=output_padding),
        )
        self.up.append(nn.BatchNorm2d(out_planes))
        if apply_dropout:
            self.up.append(nn.Dropout())
        self.up.append(nn.LeakyReLU())

        init_weight.initialize(self)

    def forward(self, inputs):
        return self.up(inputs)
class UNet(nn.Module):

    def __init__(self):
        super(UNet, self).__init__()

        down_stack = [
            pix2pix.DownSample(3, 64, 4),
            pix2pix.DownSample(64, 128, 4),
            pix2pix.DownSample(128, 256, 4),
            pix2pix.DownSample(256, 512, 4),
            pix2pix.DownSample(512, 512, 4),
            pix2pix.DownSample(512, 512, 4),
            pix2pix.DownSample(512, 512, 4),
            pix2pix.DownSample(512, 512, 4),
        ]

        up_stack = [
            pix2pix.UpSample(512, 512, 4, 1, 1, True),
            pix2pix.UpSample(512, 512, 4, 1, 1, True),
            pix2pix.UpSample(512, 512, 4, 1, 1, True),
            pix2pix.UpSample(512, 512, 4, 1, 1, True),
            pix2pix.UpSample(512, 256, 4, 1, 1, True),
            pix2pix.UpSample(256, 128, 4, 1, 1, True),
            pix2pix.UpSample(256, 128, 4, 1, 1, True),
            pix2pix.UpSample(128, 64, 4, 1, 1, True),
        ]

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for item in down_stack:
            self.encoder.append(item)

        for item in up_stack:
            self.decoder.append(item)

    def forward(self, inputs):
        feat = inputs
        for i in range(len(self.encoder)):
            feat = self.encoder[i](feat)

        for i in range(len(self.decoder)):
            feat = self.decoder[i](feat)

        return feat


if __name__ == '__main__':
    from torchsummaryX import summary

    import torch
    x = torch.ones((1, 3, 512, 512))
    u = UNet()

    summary(model=u, x=x)

That is because you are using nn.ModuleList() inside your Upsample() class. You should change it to nn.Sequential(). One way to do this is like the following:

class UpSample(nn.Module):

    def __init__(self, in_planes: int, out_planes: int,
                 kernel_size: int, padding: int, output_padding: int,
                 apply_dropout: bool = False):
        super(UpSample, self).__init__()

        self.up = nn.ModuleList()

        self.up.append(
            nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride=2,
                               padding=padding, output_padding=output_padding),
        )
        self.up.append(nn.BatchNorm2d(out_planes))
        if apply_dropout:
            self.up.append(nn.Dropout())
        self.up.append(nn.LeakyReLU())

        self.up = nn.Sequential(*self.up)    #### Use nn.Sequential() #####

        init_weight.initialize(self)

    def forward(self, inputs):
        return self.up(inputs)

For more info about nn.ModuleList() and nn.Sequential() you can read here: When should I use nn.ModuleList and when should I use nn.Sequential?

thanks very much, it works good!