I want to pass a tuple as input to the network, something like this:
def forward(self, x):
    """Run ``self.net`` on the input and (unless outermost) append a skip tensor.

    ``x`` is either a single tensor, which is used both as the network input
    and as the skip ("hint") tensor, or a 2-tuple ``(y, hint)`` where ``y``
    goes through ``self.net`` and ``hint`` is concatenated onto its output.

    Raises:
        ValueError: if ``x`` is a tuple whose length is not 2.
    """
    if isinstance(x, tuple):
        if len(x) != 2:
            # Without this check ``y``/``hint`` would be unbound below and the
            # caller would get an unhelpful UnboundLocalError.
            raise ValueError(
                "expected a tensor or a (y, hint) tuple, got a tuple of "
                "length {}".format(len(x))
            )
        y, hint = x
    else:
        # Any non-tuple input is treated as the tensor itself (no Variable
        # check), so every path binds both names.
        y = hint = x
    if self.outermost:
        # The outermost level has no skip connection; ``hint`` is unused.
        return self.net(y)
    y = self.net(y)
    # Concatenate along dim 1 (presumably the channel dim of NCHW input — confirm).
    return torch.cat([y, hint], 1)
But when I run the model, I get the error `ValueError: need more than 1 value to unpack`.
So how should I pass a tuple to the network?
Because the network is recursive, passing two separate variables also raises an error.
The full network is as follows:
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock``s.

    Construction goes innermost-first: one innermost ``ngf*8`` block, then
    ``num_downs - 5`` intermediate ``ngf*8`` blocks, then blocks stepping the
    channel count down (``ngf*8 -> ngf*4 -> ngf*2 -> ngf``), and finally an
    outermost block producing ``output_nc`` channels. That is ``num_downs``
    blocks in total.

    Args:
        input_nc: number of input channels; must equal ``output_nc`` because
            the outermost block's first conv is built with ``output_nc``
            input channels.
        output_nc: number of output channels.
        num_downs: total number of nested blocks (presumably >= 5 so the
            fixed outer levels all exist — TODO confirm intended minimum).
        ngf: base number of filters.
        norm_layer: factory for normalization layers.
        use_dropout: forwarded to the intermediate ``ngf*8`` blocks only.
        gpu_ids: stored on ``self.gpu_ids``; not used by ``forward``.
            Defaults to a fresh empty list per instance.

    Raises:
        ValueError: if ``input_nc != output_nc``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None):
        super(UnetGenerator, self).__init__()
        # A None default avoids sharing one mutable list across instances.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if input_nc != output_nc:
            raise ValueError(
                "UnetGenerator currently supports only input_nc == output_nc "
                "(got input_nc={}, output_nc={})".format(input_nc, output_nc)
            )
        # Build the U-Net from the bottleneck outwards; each block wraps the
        # previously built one as its submodule.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
        self.model = unet_block

    def forward(self, x):
        """Run the whole U-Net on ``x`` and return ``self.model``'s output unchanged.

        ``x`` is passed straight to the outermost block, so it may be a tensor
        or whatever that block accepts.
        """
        # NOTE(review): multi-GPU dispatch through
        # nn.parallel.data_parallel(self.model, x, self.gpu_ids) was disabled
        # in the original; self.gpu_ids is currently not consulted here.
        return self.model(x)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: down path, optional nested submodule, up path, skip.

    Non-outermost blocks return ``torch.cat([net(y), hint], 1)``; the
    outermost block returns ``net(y)`` with no skip concatenation. Each
    down conv uses kernel 4, stride 2, padding 1, and each up conv is the
    matching transposed conv.

    Args:
        outer_nc: channels entering the down conv and leaving the up conv.
        inner_nc: channels produced by the down conv (fed to ``submodule``).
        submodule: nested block run between the down and up paths; required
            unless ``innermost`` (and not ``outermost``).
        outermost: top level; up path ends in ``nn.Tanh`` and no skip concat.
        innermost: bottom level; no submodule, up conv takes ``inner_nc``.
        norm_layer: factory for normalization layers.
        use_dropout: append ``nn.Dropout(0.5)``; only honoured for
            intermediate (neither innermost nor outermost) blocks.

    Raises:
        ValueError: if a submodule is required but ``submodule`` is None.
    """

    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        # Fail at construction instead of with an opaque "NoneType is not
        # callable" inside nn.Sequential at forward time.
        needs_submodule = outermost or not innermost
        if needs_submodule and submodule is None:
            raise ValueError(
                "submodule is required unless innermost=True "
                "(and outermost=False)"
            )
        self.outermost = outermost
        # All layers are created in the original order, even the ones a
        # given branch does not use, so parameter initialisation order is
        # unchanged.
        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        # inplace=True: this activation overwrites its input tensor.
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            # Input to upconv is the submodule's concat output, hence * 2.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule, so no concat doubling the channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [submodule] + up
            if use_dropout:
                model = model + [nn.Dropout(0.5)]
        self.net = nn.Sequential(*model)

    def forward(self, x):
        """Apply this level to ``x`` (a tensor or a ``(y, hint)`` 2-tuple).

        A single tensor is used both as the input and as the skip tensor.
        For a tuple, ``y`` goes through ``self.net`` and ``hint`` is
        concatenated onto the result along dim 1.

        Raises:
            ValueError: if ``x`` is a tuple whose length is not 2.
        """
        if isinstance(x, tuple):
            if len(x) != 2:
                # Previously this left y/hint unbound (UnboundLocalError).
                raise ValueError(
                    "expected a tensor or a (y, hint) tuple, got a tuple of "
                    "length {}".format(len(x))
                )
            y, hint = x
        else:
            y = hint = x
        if self.outermost:
            return self.net(y)
        # NOTE: self.net is an nn.Sequential, so the nested submodule only
        # ever receives the single tensor produced by the down path, never a
        # tuple; ``hint`` is used at this level only and is not forwarded to
        # inner blocks.
        # NOTE: for non-outermost blocks, net begins with an in-place
        # LeakyReLU, so when hint is y itself the concatenated skip holds
        # the activated values.
        return torch.cat([self.net(y), hint], 1)