I want to make skip connections in PyTorch, but I got a lot of errors until I arrived at this code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, p, k=3, padding=1):
        super(Block, self).__init__()
        self.in_channels, self.out_channels, self.p, self.k, self.padding = in_channels, out_channels, p, k, padding
        self.conv = nn.Conv2d(in_channels, out_channels, k, bias=False, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.drop = nn.Dropout2d(p)

    def forward(self, x):
        x = F.leaky_relu(self.conv(x), 0.01)
        x = self.drop(x)
        x = self.bn(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()  # just run the init of the parent class (nn.Module)
        self.block1_1 = Block(in_channels=1, out_channels=10, p=.2, k=5, padding=3)
        self.block2_1 = Block(10, 30, .5, 4, 2)
        self.block3_1 = Block(30, 40, .6)
        self.block4_1 = Block(40, 50, .8)
        self.block1_2 = Block(1, 10, .2, 5, 3)
        self.block2_2 = Block(10, 30, .5, 4, 2)
        self.block3_2 = Block(30, 40, .6)
        self.block4_2 = Block(40, 50, .8)

        # Run one dummy input through convs() to work out the flattened size for fc1.
        x = torch.randn(28, 28).view(-1, 1, 28, 28)
        self._to_linear = None
        self.convs(x)

        self.fc1 = nn.Linear(self._to_linear, 512, bias=False)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10)

    def convs(self, x):
        # Branch 1: four Block/max-pool stages.
        x1_1 = F.max_pool2d(self.block1_1(x), (2, 2))
        x2_1 = F.max_pool2d(self.block2_1(x1_1), (2, 2))
        x3_1 = F.max_pool2d(self.block3_1(x2_1), (2, 2))
        x4_1 = F.max_pool2d(self.block4_1(x3_1), (2, 2))
        # Branch 2: an identical but separate stack over the same input.
        x1_2 = F.max_pool2d(self.block1_2(x), (2, 2))
        x2_2 = F.max_pool2d(self.block2_2(x1_2), (2, 2))
        x3_2 = F.max_pool2d(self.block3_2(x2_2), (2, 2))
        x4_2 = F.max_pool2d(self.block4_2(x3_2), (2, 2))
        # Merge the two branches by element-wise addition.
        x = x4_2 + x4_1
        if self._to_linear is None:
            self._to_linear = int(x[0].shape[0] * x[0].shape[1] * x[0].shape[2])
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)  # flatten
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        x = self.bn1(x)
        x = self.fc2(x)
        return x
It actually runs, but it isn't what I want: it builds two parallel convolutional stacks and sums their outputs at the very end, rather than skipping over layers within a single path.
I think there are many ways to create skip connections, and maybe there are some functions in torch.nn that would help me.
So if anyone has experience with skip connections, I would be glad to hear your opinion.
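
For reference, the pattern I have in mind is something like a ResNet-style residual block, where a block's input is added to its output, with a 1x1 convolution on the skip path when the channel counts differ. This is just a minimal sketch of the idea (the ResidualBlock name and shortcut logic are my own illustration, not part of the network above; it reuses the torch/nn/F imports from the code):

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        # 3x3 convs with padding=1 keep H and W, so the addition below lines up.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv so the skip path matches the main path's channel count.
        self.shortcut = (nn.Identity() if in_channels == out_channels
                         else nn.Conv2d(in_channels, out_channels, 1, bias=False))

    def forward(self, x):
        out = F.leaky_relu(self.bn1(self.conv1(x)), 0.01)
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # the skip connection
        return F.leaky_relu(out, 0.01)

# e.g. as a drop-in replacement for one of my blocks:
#   self.block2_1 = ResidualBlock(10, 30)

I also know of concatenation-style skips (torch.cat([x, out], dim=1), as in DenseNet or U-Net), so I am not sure which variant fits my case best.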