How can I make a non-layered feed forward NN in pytorch with lots of skip connections?

I want to set up a network that may have lots of complicated skip connections, but is still feed forward (no recurrent connections). Each node will still be a “classic” node in the sense that it just has one weight per incoming connection: it multiplies each input by the corresponding weight, sums the products, and passes the sum through a nonlinear function.
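So a single node would just be computing something like this (made-up numbers, just to be concrete):

import torch

inputs = torch.tensor([0.5, -1.0, 2.0])    # outputs of whatever nodes feed into this node
weights = torch.tensor([0.1, 0.3, -0.2])   # one weight per incoming connection
out = torch.tanh((weights * inputs).sum()) # weighted sum, then a nonlinearity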

For example, it might look like this:

However, I don’t think I can easily use something like the nn.Linear module to calculate the outputs, because (in the picture, for example) the red nodes get inputs from both the green nodes and the pink nodes, but the green nodes themselves need the pink nodes’ outputs before they can be computed.

I can think of some ways to do this, but they’re really ugly, and PyTorch usually has a smart way of doing things. Is there a good way to do this? Thanks!

You can save the intermediate outputs and concatenate them into a bigger tensor. In the example below I take the output of the first (pink) layer and concatenate it with the output of the second (green) layer; that combined tensor is the input to the red layer. Note that it’s taking the whole layer, not just the pink nodes at the end as in your drawing. Hope you get the picture and can go from there :slight_smile:

Edit: You might have to change the shapes to make them fit; the Tensor.view() method is useful for that. Also, you might find some inspiration in the DenseNet model class, which basically does this, but for CNNs.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Example sizes, adjust to your own network
        self.layer1 = nn.Linear(10, 8)     # Pink ones
        self.layer2 = nn.Linear(8, 6)      # Green ones
        self.layer3 = nn.Linear(8 + 6, 4)  # Red ones: fed by pink and green together

    def forward(self, x):
        x1 = self.layer1(x)  # Pink outputs
        x2 = self.layer2(x1) # Green outputs

        # Concatenate along the feature dimension (last dim), not the batch dimension
        pink_and_green = torch.cat((x1, x2), dim=-1)
        x3 = self.layer3(pink_and_green) # Red outputs
        return x3
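With the made-up sizes above, you can sanity-check the shapes like this:

net = Net()
x = torch.randn(5, 10)   # batch of 5 samples, 10 input features
print(net(x).shape)      # torch.Size([5, 4])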

Hi Oli, thanks for the advice. That might work. Something I can’t immediately see a solution to is this, though: in that example the nodes happened to have nice colors and were still organized in roughly a row format. In reality it might be a lot messier (though still ultimately feed forward). I want to be able to do it programmatically, so I can’t use a setup where I’d have to figure out the layers by hand every time…

Well if it’s not layers but more like a chaotic web I’m not sure how you can do that neatly. You could do lots and lots of mini-layers, but to fit them together would take something handcrafted.

Or maybe organise them in layers but put some nodes/weights to zero.
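A rough sketch of that masking idea (the sizes and the mask here are just made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    # A linear layer whose weight matrix is multiplied by a fixed 0/1 mask,
    # so the "missing" connections contribute nothing.
    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        self.register_buffer("mask", mask)

    def forward(self, x):
        return F.linear(x, self.weight * self.mask, self.bias)

# Made-up example: 3 output nodes, 4 input nodes, some connections removed
mask = torch.tensor([[1., 1., 0., 0.],
                     [0., 1., 1., 0.],
                     [1., 0., 0., 1.]])
layer = MaskedLinear(4, 3, mask)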

Also, have a look at the NEAT algorithm. That can create these structures, but it is an evolutionary algorithm.

> Well if it’s not layers but more like a chaotic web I’m not sure how you can do that neatly. You could do lots and lots of mini-layers, but to fit them together would take something handcrafted.
>
> Or maybe organise them in layers but put some nodes/weights to zero.

Well, I’m currently doing it by hand, by making the network connections myself (via dictionaries and stuff), and then just using lambda functions with the weight tensors… but it’s not pretty.
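Roughly, the by-hand version looks something like this (a heavily simplified sketch; the node names and topology are made up, and I’m using an nn.ParameterDict here rather than lambdas, but it’s the same idea):

import torch
import torch.nn as nn

class GraphNet(nn.Module):
    def __init__(self, connections):
        # connections: dict mapping each non-input node to the list of nodes that
        # feed into it, with the keys already in topological (feed-forward) order
        super().__init__()
        self.connections = connections
        self.weights = nn.ParameterDict({
            node: nn.Parameter(torch.randn(len(parents)))
            for node, parents in connections.items()
        })

    def forward(self, inputs):
        # inputs: dict mapping input-node names to scalar tensors
        values = dict(inputs)
        for node, parents in self.connections.items():
            incoming = torch.stack([values[p] for p in parents])
            values[node] = torch.tanh((self.weights[node] * incoming).sum())
        return values

# Made-up topology: two inputs, one hidden node, one output with a skip connection
net = GraphNet({"h1": ["in1", "in2"], "out": ["in1", "h1"]})
result = net({"in1": torch.tensor(0.5), "in2": torch.tensor(-1.0)})["out"]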

> Also, have a look at the NEAT algorithm. That can create these structures, but it is an evolutionary algorithm.

Guess exactly what I’m using it for? :wink:
