I got the below error, how can I solve the problem:
RuntimeError: torch.cat(): Sizes of tensors must match except in dimension 1. Got 2 and 4 in dimension 2 (The offending index is 1)
class model(nn.Module):
'''
Discriminator Class
Values:
base_channels: the number of base channels, a scalar
n_classes: the number of image classes, a scalar
'''
def __init__(self, base_channels=64, n_classes=2):
super().__init__()
self.d_blocks1= nn.Sequential(
DResidualBlock(3, base_channels, downsample=True, use_preactivation=False),
AttentionBlock(base_channels), nn.ReLU(inplace=True))
self.d_blocks2= nn.Sequential(
DResidualBlock(base_channels, 2 * base_channels, downsample=True, use_preactivation=True),
AttentionBlock(2 * base_channels), nn.ReLU(inplace=True))
self.d_blocks3= nn.Sequential(
DResidualBlock(2 * base_channels, 4 * base_channels, downsample=True, use_preactivation=True),
AttentionBlock(4 * base_channels), nn.ReLU(inplace=True))
self.d_blocks4= nn.Sequential(
DResidualBlock(4 * base_channels, 8 * base_channels, downsample=True, use_preactivation=True),
AttentionBlock(8 * base_channels), nn.ReLU(inplace=True))
#Bridge
self.d_blocks5= nn.Sequential(
DResidualBlock(8 * base_channels, 16 * base_channels, downsample=True, use_preactivation=True),
AttentionBlock(16 * base_channels), nn.ReLU(inplace=True))
# self.bridge = DResidualBlock(8 * base_channels, 16 * downsample=True, use_preactivation=True)
self.up1 = Upsample(16 * base_channels, 16* base_channels)
self.g_blocks1 = nn.ModuleList([
GResidualBlock_UNET(16 * base_channels, 8 * base_channels, skip_connection =True),
AttentionBlock(8 * base_channels),
])
self.up2 = Upsample(8 * base_channels, 8* base_channels)
self.g_blocks2 = nn.ModuleList([
GResidualBlock_UNET( 8 * base_channels, 4 * base_channels, skip_connection =True),
AttentionBlock(4 * base_channels),
])
self.up3 = nn.Upsample(4 * base_channels, 4* base_channels)
self.g_blocks3 = nn.ModuleList([
GResidualBlock_UNET( 4 * base_channels, 2 * base_channels, skip_connection =True ),
AttentionBlock(2 * base_channels),
])
self.up4 = nn.Upsample(2 * base_channels, 2* base_channels)
self.g_blocks4 = nn.ModuleList([
GResidualBlock_UNET(2 * base_channels, base_channels, skip_connection =True),
AttentionBlock(base_channels),
])
self.conv2 = nn.utils.spectral_norm(nn.Conv2d(base_channels, 1, 1, 1))
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU()
self.proj_o = nn.utils.spectral_norm(nn.Linear(16 * base_channels, 1))
def forward(self, x):
x1 = x.contiguous()
x1 = x1.view(-1, 3, base_channels, base_channels)
x1 = self.d_blocks1(x1)
x2 = self.d_blocks2(x1)
x3 = self.d_blocks3(x2)
x4 = self.d_blocks4(x3)
x5 = self.d_blocks5(x4)
h6 = self.up1(x5)
h6 = torch.cat([h6, x4], dim=1)
h6 = self.g_blocks1(h6)
h7 = self.up2(h6)
h7 = torch.cat([h7, x3], dim=1)
h7 = self.g_blocks2(h7)
h8 = self.up3(h7)
h8 = torch.cat([h8, x2], dim=1)
h8 = self.g_blocks3(h8)
h9 = self.up4(h8)
h9 = torch.cat([h9, x1], dim=1)
h9 = self.g_blocks4(h9)
h = self.conv2(h9)
h = self. sigmoid(h)
h = self.relu(h)
h = torch.sum(h, dim=[2,3])
# Class-unconditional output
uncond_out = self.proj_o(h)
return uncond_out ```