How to write your own skip connections in PyTorch?

Hello. The following is the code I am using with skip connections:

class PeakNet(nn.Module):
    """1-D residual CNN: five two-conv residual stages followed by an MLP head.

    Each stage applies ``conv -> leaky_relu -> dropout`` twice, adds the stage
    input back through an identity skip connection, then halves the length
    with ``MaxPool1d(2, 2)``. Every conv maps 1 channel to 1 channel and, with
    ``kernel_size=31`` / ``padding=15`` / ``stride=1``, preserves the sequence
    length, so the identity skip needs no projection.

    Args:
        input_length: Length of the single-channel input sequence. The five
            pooling stages reduce it to ``input_length // 32`` features, which
            is the input size of ``fc1``. The default (1024) reproduces the
            original ``nn.Linear(32, 128)``.
        num_classes: Number of outputs of ``fc3`` (default 9).
        final_relu: If True (default, original behaviour), a ReLU is applied to
            the final outputs, so they are never negative. Pass False to get
            raw, unbounded outputs, e.g. logits for ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: if ``input_length`` is too short to survive five poolings.

    Expects input of shape ``(batch, 1, input_length)`` and returns
    ``(batch, num_classes)``.
    """

    _NUM_STAGES = 5  # residual stages; each holds two convs and one pooling

    def __init__(self, *, input_length=1024, num_classes=9, final_relu=True):
        super().__init__()
        downsample = 2 ** self._NUM_STAGES
        pooled_length = input_length // downsample
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least {downsample}, got {input_length}"
            )
        self.final_relu = final_relu

        # Registered one by one as conv1 ... conv10 (instead of a ModuleList)
        # so state_dict keys and parameter order match the original layout.
        # Each conv is Xavier-initialised right after construction, which also
        # keeps the RNG consumption order of the original code.
        for idx in range(1, 2 * self._NUM_STAGES + 1):
            conv = nn.Conv1d(
                in_channels=1, out_channels=1, kernel_size=31, stride=1, padding=15
            )
            torch.nn.init.xavier_uniform_(conv.weight)
            setattr(self, f"conv{idx}", conv)

        self.drop_layer_conv = nn.Dropout(p=0.1)
        self.drop_layer_fc = nn.Dropout(p=0.5)
        self.pool = nn.MaxPool1d(2, 2)
        self.fc1 = nn.Linear(pooled_length, 128)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.dense1_bn = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
        self.dense2_bn = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, num_classes)
        torch.nn.init.xavier_uniform_(self.fc3.weight)
        self.skip = nn.Identity()

    def forward(self, x):
        for stage in range(self._NUM_STAGES):
            residual = x
            first = getattr(self, f"conv{2 * stage + 1}")
            second = getattr(self, f"conv{2 * stage + 2}")
            x = self.drop_layer_conv(F.leaky_relu(first(x)))
            x = self.drop_layer_conv(F.leaky_relu(second(x)))
            # Identity skip around the two convs, then halve the length.
            x = self.pool(x + self.skip(residual))
        # Drop the single channel dim: (batch, 1, L) -> (batch, L).
        x = x.squeeze(1)
        # NOTE(review): fc1 uses leaky_relu but fc2 uses relu; kept as in the
        # original. Confirm this asymmetry is intended.
        x = self.drop_layer_fc(F.leaky_relu(self.dense1_bn(self.fc1(x))))
        x = self.drop_layer_fc(F.relu(self.dense2_bn(self.fc2(x))))
        x = self.fc3(x)
        # A final ReLU zeroes negative outputs and blocks their gradient; it is
        # kept by default for backward compatibility.
        return F.relu(x) if self.final_relu else x

However, when I look at the gradient flow every 10 epochs, I do not see the expected benefit of the skip connections (i.e., preventing vanishing gradients). Instead, the layers close to the input have gradients close to 0, while only the fully connected layers have reasonable gradient values.

Can someone verify whether I am using skip connections correctly here? Thanks!

The skip connections look correct; the ResNet implementation uses a similar approach.

I’m not sure you really want to apply a ReLU as the last non-linearity.
If you are dealing with a multi-class classification use case, this would clamp all negative logits to zero, and those clamped outputs would also receive no gradient.
However, I’m not familiar with your use case.