I encountered this error while running UNet. My code is as follows. What is the issue? My input image is 256 x 256.
from unet_parts import *
class UNet(nn.Module):
def init(self, n_channels=1, n_classes=1, bilinear=False):
super(UNet, self).init()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = (DoubleConv(n_channels, 64))
self.down1 = (Down(64, 128))
self.down2 = (Down(128, 256))
self.down3 = (Down(256, 512))
factor = 2 if bilinear else 1
self.down4 = (Down(512, 1024 // factor))
self.up1 = (Up(1024, 512 // factor, bilinear))
self.up2 = (Up(512, 256 // factor, bilinear))
self.up3 = (Up(256, 128 // factor, bilinear))
self.up4 = (Up(128, 64, bilinear))
self.outc1 = (OutConv(64, n_classes))
self.outc2 = (OutConv(64, n_classes))
self.outc3 = (OutConv(64, n_classes))
self.outc4 = (OutConv(64, n_classes))
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
fake_bin1 = self.outc1(x)
fake_bin2 = self.outc2(x)
fake_bin3 = self.outc3(x)
fake_bin4 = self.outc4(x)
return fake_bin1, fake_bin2, fake_bin3, fake_bin4
def use_checkpointing(self):
self.inc = torch.utils.checkpoint(self.inc)
self.down1 = torch.utils.checkpoint(self.down1)
self.down2 = torch.utils.checkpoint(self.down2)
self.down3 = torch.utils.checkpoint(self.down3)
self.down4 = torch.utils.checkpoint(self.down4)
self.up1 = torch.utils.checkpoint(self.up1)
self.up2 = torch.utils.checkpoint(self.up2)
self.up3 = torch.utils.checkpoint(self.up3)
self.up4 = torch.utils.checkpoint(self.up4)
self.outc1 = torch.utils.checkpoint(self.outc1)
self.outc2 = torch.utils.checkpoint(self.outc2)
self.outc3 = torch.utils.checkpoint(self.outc3)
self.outc4 = torch.utils.checkpoint(self.outc4)