Creating skip connections across nn.Modules

Hi, :wave: I'm looking to implement U-Net-like long skip connections across the encoder and the decoder. This snippet illustrates my problem.

blocks.extend([nn.BatchNorm2d(self.channel), nn.ReLU(inplace=True)])  # append BatchNorm2d + ReLU

self.blocks = nn.Sequential(*blocks)

def forward(self, input):
    return self.blocks(input)

Because I have 50+ layers in each Module, I can't expose them individually. But otherwise, I don't see how I can add skip connections from the encoder to the decoder, as in U-Net.

Does anyone have ideas on how I can do this simply? I'm imagining something like a skip connection from the i-th layer in the encoder to the (len(decoder) - i)-th layer in the decoder (or maybe something a bit more complex, like U-Net, you get the idea).

P.S.: I am using a VQ-VAE-2, but for all intents and purposes it's mostly a U-Net with a quantization layer added, so there shouldn't be any major hiccups :frowning:

Thanks! :hugs:


The common approach would be to store the output activations of the encoder layers and reuse them directly in the decoder path. If you are seeing issues with this approach due to the large number of layers, I would suggest storing the outputs in e.g. a list and using a loop to execute the blocks of both parts of the model.
E.g. something like this might work:

def forward(self, x):
    out = x
    acts = []
    for layer in self.encoder:
        out = layer(out)
        acts.append(out)

    # iterate the stored activations in reverse so the deepest
    # encoder output is fused first in the decoder path
    for a, layer in zip(acts[::-1], self.decoder):
        out = layer(out, a)

    return out
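For completeness, here is a minimal self-contained sketch of that pattern. The `DecoderBlock` and `SkipNet` names, channel counts, and layer choices are made up for illustration, assuming the decoder fuses each skip by channel-wise concatenation:

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Hypothetical decoder block that fuses a skip activation by concatenation."""
    def __init__(self, channels):
        super().__init__()
        # in_channels doubles because the skip activation is concatenated
        self.conv = nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, skip):
        return self.act(self.conv(torch.cat([x, skip], dim=1)))

class SkipNet(nn.Module):
    """Toy encoder/decoder with one skip connection per encoder layer."""
    def __init__(self, channels=8, depth=4):
        super().__init__()
        self.encoder = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth)
        )
        self.decoder = nn.ModuleList(DecoderBlock(channels) for _ in range(depth))

    def forward(self, x):
        out = x
        acts = []
        for layer in self.encoder:
            out = layer(out)
            acts.append(out)
        # mirror the encoder: the deepest activation is consumed first
        for a, layer in zip(acts[::-1], self.decoder):
            out = layer(out, a)
        return out
```

With same-padded convolutions the spatial shape is preserved, so `SkipNet()(torch.randn(1, 8, 16, 16))` returns a tensor of the same shape.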

thanks! Even if I selectively append layer activations to the list, won't it still be a bottleneck for each forward pass, since I have to iterate twice, append, and compute? :thinking:

I’m not sure why the for loop would create a bottleneck compared to the sequential execution of the layers. Could you explain your concern a bit more, please?

I suppose I was concerned about appending so many layer activations to a list and iterating over them, rebuilding the same list on every single forward pass. It isn't really a deal-breaker for me since I'm in the fine-tuning stage, but someone else with high-performance needs might not find the solution optimal :man_shrugging:

If I understand your concern correctly, you're afraid that e.g.

for module in modules:
    out = module(out)

would be slower than:

out = module1(out)
out = module2(out)
out = module3(out)
...
out = module50(out)

If so, then I would claim that the Python overhead (if any) should be small compared to the actual workload execution, but you should certainly profile it if this use case fits your concern.
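One quick way to profile this is to time the manual loop against `nn.Sequential` (which also iterates over its children in a Python loop internally). The layer type, sizes, and iteration counts below are arbitrary choices for illustration:

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
modules = nn.ModuleList(nn.Linear(256, 256) for _ in range(50))
seq = nn.Sequential(*modules)  # the same 50 layers, same weights

x = torch.randn(64, 256)

def run_loop(t):
    # manual Python loop over the layers
    for m in modules:
        t = m(t)
    return t

with torch.no_grad():
    # both paths apply the identical layers in the identical order
    assert torch.allclose(run_loop(x), seq(x))
    for fn, name in ((run_loop, "manual loop"), (seq, "nn.Sequential")):
        fn(x)  # warm-up
        start = time.perf_counter()
        for _ in range(100):
            fn(x)
        print(f"{name}: {time.perf_counter() - start:.3f}s for 100 passes")
```

On a typical setup the two timings are close, since the per-iteration Python overhead is dwarfed by the layer computation itself, but the actual numbers depend on your hardware and layer sizes.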

lovely! :ok_hand: :100:
Just one last question: in your snippet above, every encoder layer gets a skip connection to its mirrored decoder layer. My network is mostly composed of resblocks on both sides (and the task is semantic segmentation).

In that case, should I add just a few (say 7 or 11) long skip connections between the encoder and decoder, or should I stick with one connection per layer, exactly as in your snippet? :thinking:

Thanks a ton again!!! :hugs:

Hmm, that's an interesting question and I would probably start with the "simpler" implementation. I.e. if your current code is easier to write by creating a skip connection in each layer, try that first. Then you could add conditions for where the skip connections should be used and compare it against the first approach.
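The "few long skips" variant could be sketched by storing activations only at chosen encoder depths and fusing them at the mirrored decoder positions. Everything here (the class name, the additive fusion, the default indices) is a hypothetical illustration, not the poster's actual VQ-VAE-2 code:

```python
import torch
import torch.nn as nn

class SparseSkipNet(nn.Module):
    """Toy encoder/decoder with skips only at selected encoder depths."""
    def __init__(self, channels=8, depth=6, skip_indices=(1, 3, 5)):
        super().__init__()
        self.skip_indices = set(skip_indices)
        self.encoder = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth)
        )
        self.decoder = nn.ModuleList(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth)
        )

    def forward(self, x):
        out = x
        acts = {}
        for i, layer in enumerate(self.encoder):
            out = layer(out)
            if i in self.skip_indices:  # only keep the selected activations
                acts[i] = out
        depth = len(self.decoder)
        for j, layer in enumerate(self.decoder):
            out = layer(out)
            # decoder layer j mirrors encoder layer depth - 1 - j
            mirror = depth - 1 - j
            if mirror in acts:
                out = out + acts[mirror]  # additive skip keeps channels unchanged
        return out
```

Additive fusion is used here so the channel count stays fixed; a concatenation-based fusion as in U-Net would also work but requires adjusting the decoder's input channels at the skip positions.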
