I’m a pytorch beginner, i try to write a unet, this is my code, when i use pytorch summary to summary my model output, i got this error: TypeError: forward() takes 1 positional argument but 2 were given
i don’t know why, who can help me?
class DownSample(nn.Module):
def __init__(self, in_planes: int, out_planes: int, kernel_size: int):
super(DownSample, self).__init__()
self.down = nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=1),
nn.BatchNorm2d(out_planes),
nn.LeakyReLU()
)
init_weight.initialize(self)
def forward(self, x):
return self.down(x)
class UpSample(nn.Module):
def __init__(self, in_planes: int, out_planes: int,
kernel_size: int, padding: int, output_padding: int,
apply_dropout: bool = False):
super(UpSample, self).__init__()
self.up = nn.ModuleList()
self.up.append(
nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride=2,
padding=padding, output_padding=output_padding),
)
self.up.append(nn.BatchNorm2d(out_planes))
if apply_dropout:
self.up.append(nn.Dropout())
self.up.append(nn.LeakyReLU())
init_weight.initialize(self)
def forward(self, inputs):
return self.up(inputs)
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
down_stack = [
pix2pix.DownSample(3, 64, 4),
pix2pix.DownSample(64, 128, 4),
pix2pix.DownSample(128, 256, 4),
pix2pix.DownSample(256, 512, 4),
pix2pix.DownSample(512, 512, 4),
pix2pix.DownSample(512, 512, 4),
pix2pix.DownSample(512, 512, 4),
pix2pix.DownSample(512, 512, 4),
]
up_stack = [
pix2pix.UpSample(512, 512, 4, 1, 1, True),
pix2pix.UpSample(512, 512, 4, 1, 1, True),
pix2pix.UpSample(512, 512, 4, 1, 1, True),
pix2pix.UpSample(512, 512, 4, 1, 1, True),
pix2pix.UpSample(512, 256, 4, 1, 1, True),
pix2pix.UpSample(256, 128, 4, 1, 1, True),
pix2pix.UpSample(256, 128, 4, 1, 1, True),
pix2pix.UpSample(128, 64, 4, 1, 1, True),
]
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
for item in down_stack:
self.encoder.append(item)
for item in up_stack:
self.decoder.append(item)
def forward(self, inputs):
feat = inputs
for i in range(len(self.encoder)):
feat = self.encoder[i](feat)
for i in range(len(self.decoder)):
feat = self.decoder[i](feat)
return feat
if __name__ == '__main__':
from torchsummaryX import summary
import torch
x = torch.ones((1, 3, 512, 512))
u = UNet()
summary(model=u, x=x)