[resolved] How to understand the densenet implementation?

The DenseNet paper states that “the ith layer receives the feature-maps of all preceding layers”. In the PyTorch (torchvision) implementation, the code is:

import torch
import torch.nn as nn
import torch.nn.functional as F


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        # Bottleneck: BN -> ReLU -> 1x1 conv reducing to bn_size * growth_rate channels
        # (module names written without dots; recent PyTorch rejects '.' in add_module names)
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        # BN -> ReLU -> 3x3 conv producing growth_rate new feature maps
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        # Run x through the Sequential pipeline registered above
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Concatenate the new feature maps onto the input along the channel dimension
        return torch.cat([x, new_features], 1)

May I ask how to understand the forward method in the code above, especially the line new_features = super(_DenseLayer, self).forward(x)?
Thanks!


Answered by myself: it seems that this line, new_features = super(_DenseLayer, self).forward(x), results in a recursive call, which meets the DenseNet requirement exactly.
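For example, stacking two of these layers (a minimal sketch, assuming the _DenseLayer class quoted above is in scope and the channel sizes are just example values) shows each layer's input growing by growth_rate channels, i.e. every layer receives the feature maps of all preceding layers via the concatenation:

    import torch

    growth_rate, bn_size = 32, 4
    x = torch.randn(1, 64, 56, 56)                                 # 64 input feature maps

    layer1 = _DenseLayer(64, growth_rate, bn_size, drop_rate=0.0)
    layer2 = _DenseLayer(64 + growth_rate, growth_rate, bn_size, drop_rate=0.0)

    out1 = layer1(x)      # cat([x, 32 new maps])    -> 96 channels
    out2 = layer2(out1)   # cat([out1, 32 new maps]) -> 128 channels
    print(out1.shape, out2.shape)
    # torch.Size([1, 96, 56, 56]) torch.Size([1, 128, 56, 56])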


Hi, I still don’t understand why it results in a recursive call. Could you explain in more detail?

Hoping you can share more details about your understanding. I have the same question.

I was confused by this pattern as well, but it makes sense once you understand the parent class. I figured out that it works as follows:

    super(_DenseLayer, self).forward(x) 

calls the forward() method of the parent class of _DenseLayer (nn.Sequential) via the super keyword, which is a cleaner way to refer to the class that _DenseLayer inherits from. The source code of torch.nn.Sequential.forward() is:

    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

So it takes all of the modules that were added to the class with self.add_module and passes the input through each of them in order, feeding each module’s output in as the next module’s input. Essentially, this lets forward() run the entire BN/ReLU/conv pipeline without a chain of explicit calls cluttering the code.
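To make that concrete, here is a rough unrolled equivalent of what super(_DenseLayer, self).forward(x) computes for the layer above (only a sketch; the helper name is made up, and integer indexing works because _DenseLayer subclasses nn.Sequential):

    def bottleneck_features(layer, x):
        # what super(_DenseLayer, self).forward(x) effectively does
        out = x
        out = layer[0](out)   # BatchNorm2d(num_input_features)
        out = layer[1](out)   # ReLU
        out = layer[2](out)   # 1x1 Conv2d -> bn_size * growth_rate channels
        out = layer[3](out)   # BatchNorm2d
        out = layer[4](out)   # ReLU
        out = layer[5](out)   # 3x3 Conv2d -> growth_rate channels
        return out            # this is new_features, which forward() then concatenates with x

Together with the torch.cat in forward(), this is how each layer’s output carries the feature maps of all preceding layers forward to the next one.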

Hopefully this puts the question to rest!
