I am trying to take someone else's code and replace element-wise multiplication with concatenation, because it's supposed to improve performance. The only line I changed from the original code is

x *= ds_outputs[-i - 1]

to

x = torch.cat((x, ds_outputs[-i - 1]), 1)
and now I get this error:
RuntimeError: Given groups=1, weight of size [160, 160, 3, 3], expected input[2, 320, 16, 128] to have 160 channels, but got 320 channels instead.
Can anyone please help with this?
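Here is a small standalone example of the shape behaviour I think is causing it (the tensor sizes are made up to match the error, not taken from the real model):

import torch

x = torch.randn(2, 160, 16, 128)     # upsampled decoder output
skip = torch.randn(2, 160, 16, 128)  # matching encoder output

print((x * skip).shape)               # torch.Size([2, 160, 16, 128]) - multiply keeps the channel count
print(torch.cat((x, skip), 1).shape)  # torch.Size([2, 320, 16, 128]) - cat stacks the channels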
The hyperparameters are l = 3, g = 32, k = 3, bn = 8, dim_c = 4, scale = (2, 2), dim_f = 2048, and here is the model code:
import torch
import torch.nn as nn

# (body of the model's __init__)
self.first_conv = nn.Sequential(
    nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
    nn.GroupNorm(4, g),
    nn.ReLU(),
)
f = self.dim_f
c = g
self.encoding_blocks = nn.ModuleList()
self.ds = nn.ModuleList()
for i in range(self.n):
    self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias))
    self.ds.append(
        nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
            nn.GroupNorm(4, c + g),
            nn.ReLU()
        )
    )
    f = f // 2
    c += g
self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias)
self.decoding_blocks = nn.ModuleList()
self.us = nn.ModuleList()
for i in range(self.n):
    self.us.append(
        nn.Sequential(
            nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
            nn.GroupNorm(4, c - g),
            nn.ReLU()
        )
    )
    f = f * 2
    c -= g
    self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias))
self.final_conv = nn.Sequential(
    nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
)

def forward(self, x):
    x = self.first_conv(x)
    x = x.transpose(-1, -2)
    ds_outputs = []
    for i in range(self.n):
        x = self.encoding_blocks[i](x)
        ds_outputs.append(x)
        x = self.ds[i](x)
    x = self.bottleneck_block(x)
    for i in range(self.n):
        x = self.us[i](x)
        x = torch.cat((x, ds_outputs[-i - 1]), 1)  # my change; original was: x *= ds_outputs[-i - 1]
        x = self.decoding_blocks[i](x)
    x = x.transpose(-1, -2)
    x = self.final_conv(x)
    return x
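If I understand the shapes right, after my change each decoding block receives c channels from the upsampling path plus c channels from the skip connection, i.e. 2 * c, while TFC_TDF was built for c. My guess at a fix (untested, and the self.merge name is mine) is a 1x1 conv per decoder level to bring the channels back down before the decoding block:

# in __init__, next to self.us:
self.merge = nn.ModuleList()

# inside the decoder loop, after c -= g:
self.merge.append(nn.Conv2d(in_channels=2 * c, out_channels=c, kernel_size=(1, 1)))

# and in forward, the decoder loop becomes:
for i in range(self.n):
    x = self.us[i](x)
    x = torch.cat((x, ds_outputs[-i - 1]), 1)  # channels: c -> 2 * c
    x = self.merge[i](x)                       # channels: 2 * c -> c
    x = self.decoding_blocks[i](x)

Is that the right way to go about it, or is there a cleaner fix? The supporting classes are below.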
class TFC(nn.Module):
    def __init__(self, c, l, k):
        super(TFC, self).__init__()
        self.H = nn.ModuleList()
        for i in range(l):
            self.H.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    nn.GroupNorm(4, c),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        for h in self.H:
            x = h(x)
        return x
class DenseTFC(nn.Module):
    def __init__(self, c, l, k):
        super(DenseTFC, self).__init__()
        self.conv = nn.ModuleList()
        for i in range(l):
            self.conv.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    nn.GroupNorm(4, c),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        for layer in self.conv[:-1]:
            x = torch.cat([layer(x), x], 1)
        return self.conv[-1](x)
class TFC_TDF(nn.Module):
    def __init__(self, c, l, f, k, bn, dense=False, bias=True):
        super(TFC_TDF, self).__init__()
        self.use_tdf = bn is not None
        self.tfc = DenseTFC(c, l, k) if dense else TFC(c, l, k)
        if self.use_tdf:
            if bn == 0:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f, bias=bias),
                    nn.GroupNorm(4, c),
                    nn.ReLU()
                )
            else:
                self.tdf = nn.Sequential(
                    nn.Linear(f, f // bn, bias=bias),
                    nn.GroupNorm(4, c),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    nn.GroupNorm(4, c),
                    nn.ReLU()
                )

    def forward(self, x):
        x = self.tfc(x)
        return x + self.tdf(x) if self.use_tdf else x
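For what it's worth, I can reproduce the exact error with a single block built from the classes above (c=160 matches the weight size in the error; f=128 is just what the input shape in the error implies, so both values are my inference, not taken from the original config):

import torch

block = TFC_TDF(c=160, l=3, f=128, k=3, bn=8)

# what the block receives after the cat: 160 + 160 = 320 channels
x = torch.randn(2, 320, 16, 128)
block(x)
# RuntimeError: Given groups=1, weight of size [160, 160, 3, 3],
# expected input[2, 320, 16, 128] to have 160 channels, but got 320 channels instead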