The skip connections are defined inside self-contained modules (Bottleneck and BasicBlock). Because they live in the forward of those submodules, they are kept when the model is rebuilt with nn.Sequential. If the skip connections were instead applied in the forward pass of the ResNet class itself, they would be lost: nn.Sequential(*model.children()) only chains the registered submodules and never runs the parent's forward. For reference, see the forward method of the ResNet class in the timm source.
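To make that concrete (using torchvision's resnet18 here only because its source is short; timm's ResNet is organised similarly), a quick sketch of what children() actually exposes:

import torch
import torchvision

m = torchvision.models.resnet18()
print([name for name, _ in m.named_children()])
# ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2',
#  'layer3', 'layer4', 'avgpool', 'fc']

# Each BasicBlock inside layer1..layer4 is a child module, so its
# forward (including the residual addition) still runs when called
# from an nn.Sequential. But torch.flatten(x, 1), which
# ResNet.forward applies between avgpool and fc, is not a module
# and is silently dropped by nn.Sequential(*m.children()).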
Here is a little dummy example to show what I mean.
The BasicBlock adds 2 to the input through a helper method called from its forward.
The ResN module does exactly the same in its own forward, adding 5 instead.
When I run a tensor of ones through the whole model, I get 8 in every entry, as expected (1 + 5 + 2).
If I now do feature extraction by putting model.children() into an nn.Sequential,
as you suggested, the result is 3.
That is not the same as the whole module:
the +5 from ResN's forward is gone,
while the +2 from the block is kept, because the block is a registered submodule.
Hope this helps.
import torch


class BasicBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def _functional_API(self, x):
        # +2 lives inside this submodule, so it survives the rebuild.
        return x + 2

    def forward(self, x):
        x = self._functional_API(x)
        return x


class ResN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blk = BasicBlock()

    def _functional_API(self, x):
        # +5 is applied directly in ResN's forward and is not a module,
        # so nn.Sequential(*model.children()) never executes it.
        return x + 5

    def forward(self, x):
        x = self._functional_API(x)
        return self.blk(x)


model = ResN()
# children() yields only the registered submodule self.blk.
feature_extraction = torch.nn.Sequential(*list(model.children()))

tensor = torch.ones(5, 5)
print(model(tensor))               # 1 + 5 + 2 = 8
print(feature_extraction(tensor))  # 1 + 2 = 3
# Output:
tensor([[8., 8., 8., 8., 8.],
[8., 8., 8., 8., 8.],
[8., 8., 8., 8., 8.],
[8., 8., 8., 8., 8.],
[8., 8., 8., 8., 8.]])
tensor([[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.],
[3., 3., 3., 3., 3.]])
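A minimal sketch of one way around this, staying with the dummy example above: register the +5 as its own submodule instead of applying it functionally in ResN's forward, so that children() picks it up. The Add5 and ResNFixed names are made up for illustration.

import torch


class Add5(torch.nn.Module):
    # Hypothetical wrapper: once the +5 is a registered submodule,
    # nn.Sequential(*model.children()) will include it.
    def forward(self, x):
        return x + 5


class ResNFixed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.add5 = Add5()       # registered first, matching forward order
        self.blk = BasicBlock()  # BasicBlock from the example above

    def forward(self, x):
        return self.blk(self.add5(x))


model = ResNFixed()
feature_extraction = torch.nn.Sequential(*list(model.children()))
tensor = torch.ones(5, 5)
print(model(tensor))               # all 8s
print(feature_extraction(tensor))  # all 8s now too

For real models this manual rewrapping does not scale, which is why FX-based extraction (torchvision.models.feature_extraction.create_feature_extractor) or timm's features_only=True option is usually the safer route: both work from the actual forward, functional ops included, instead of reassembling children.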