Hi I build this model for super ressolution and at the end of it i want to split bypass with x
x = self.output(torch.cat([x, x_bypass], dim=1))
This line
It throws error
RuntimeError: Given groups=1, weight of size [3, 64, 5, 5], expected input[3, 67, 512, 512] to have 64 channels, but got 67 channels instead
Any sugesstions how to solve this ?
class Super_ress_model(torch.nn.Module):
def __init__(self, input_shape= (3, 128, 128), output_shape= (3, 384, 384)):
super(Super_ress_model, self).__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.upsample = torch.nn.Upsample(size=None, scale_factor=8, mode='nearest', align_corners=None)
self.input = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=9 // 2)
self.act1 = nn.ReLU()
self.block0 = ResidualBlock(64,5)
self.block1 = ResidualBlock(64,5)
self.block2 = ResidualBlock(64,5)
self.block3 = ResidualBlock(64,5)
self.upsample0 = torch.nn.Upsample(size=None, scale_factor=2, mode='nearest', align_corners=None)
self.block6 = ResidualBlock(64, 5)
self.block7 = ResidualBlock(64, 5)
self.upsample1 = torch.nn.Upsample(size=None, scale_factor=2, mode='nearest', align_corners=None)
self.block10 = ResidualBlock(64, 5)
self.block11 = ResidualBlock(64, 5)
self.upsample2 = torch.nn.Upsample(size=None, scale_factor=2, mode='nearest', align_corners=None)
self.block13 = ResidualBlock(64,5)
self.block14 = ResidualBlock(64,5)
self.conv1 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=5 // 2)
self.act2 = nn.ReLU()
self.output = nn.Conv2d(64, 3, kernel_size=5, stride=1, padding=5 // 2)
torch.nn.init.xavier_uniform_(self.input.weight)
torch.nn.init.zeros_(self.input.bias)
torch.nn.init.xavier_uniform_(self.conv1.weight)
torch.nn.init.zeros_(self.conv1.bias)
torch.nn.init.xavier_uniform_(self.output.weight)
torch.nn.init.zeros_(self.output.bias)
def forward(self,x):
x_bypass = self.upsample(x)
x = self.input(x)
x = self.act1(x)
x = self.block0.forward(x)
x = self.block1.forward(x)
x = self.block2.forward(x)
x = self.block3.forward(x)
x = self.upsample0(x)
x = self.block6.forward(x)
x = self.block7.forward(x)
x = self.upsample1(x)
x = self.block10.forward(x)
x = self.block11.forward(x)
x = self.upsample2(x)
x = self.block13.forward(x)
x = self.block14.forward(x)
x = self.conv1(x)
x = self.act2(x)
x = self.output(torch.cat([x, x_bypass], dim=1))
return x