Hi, it's pretty common to extract and process intermediate activations for additional operations.
However, because I don't fully understand how the computation graph works in PyTorch, I have a couple of questions, which I'll split into two cases:
Case 1: define the function inside the class

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d, Conv2d, MaxPool2d
from torchsummary import summary  # assumes the torchsummary package


class MODEL(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = BatchNorm2d(32)
        self.pooling1 = MaxPool2d(2, 2)
        self.conv2 = Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = BatchNorm2d(64)
        self.pooling2 = MaxPool2d(2, 2)
        # extra conv applied by feature_extractor()
        self.extra_conv1 = Conv2d(32, 32, kernel_size=3, padding=1)

    def feature_extractor(self, x):
        out = self.extra_conv1(x)
        return out

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_extractor(x)
        x = self.pooling1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pooling2(x)
        return x


model = MODEL()
summary(model, (3, 64, 64))
```
The code above is a simple example of extracting features. In this case, I get Total params: 28,832.
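As a sanity check on the two totals, here is a minimal sketch that recomputes them by hand. It assumes the standard counts: a `Conv2d` with bias has `in_ch * out_ch * k * k + out_ch` parameters, and an affine `BatchNorm2d` has one weight and one bias per channel (its running statistics are buffers, not parameters). The helper names `conv2d_params` and `bn2d_params` are hypothetical.

```python
# Hypothetical helpers: count parameters of each layer type by formula.
def conv2d_params(in_ch: int, out_ch: int, k: int) -> int:
    # weights (in_ch * out_ch * k * k) plus one bias per output channel
    return in_ch * out_ch * k * k + out_ch


def bn2d_params(ch: int) -> int:
    # affine BatchNorm2d: one weight and one bias per channel
    return 2 * ch


# Layers present in both cases
shared = (
    conv2d_params(3, 32, 3)     # conv1: 896
    + bn2d_params(32)           # bn1: 64
    + conv2d_params(32, 64, 3)  # conv2: 18,496
    + bn2d_params(64)           # bn2: 128
)
extra_conv1 = conv2d_params(32, 32, 3)  # 9,248

case1_total = shared + extra_conv1
case2_total = shared

print(f"Case 1: {case1_total:,}")  # Case 1: 28,832
print(f"Case 2: {case2_total:,}")  # Case 2: 19,584
```

The gap between the two totals is exactly the 9,248 parameters of one 32-to-32, 3x3 convolution, i.e. `extra_conv1`.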
Case 2: define the function outside the class

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d, Conv2d, MaxPool2d
from torchsummary import summary  # assumes the torchsummary package


def feature_extractor(x):
    in_channels = x.size(1)
    # a new Conv2d is constructed every time this function is called
    extractor = Conv2d(in_channels, 32, kernel_size=3, padding=1)
    out = extractor(x)
    return out


class MODEL(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = BatchNorm2d(32)
        self.pooling1 = MaxPool2d(2, 2)
        self.conv2 = Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = BatchNorm2d(64)
        self.pooling2 = MaxPool2d(2, 2)
        # stores the plain function defined above, not an nn.Module
        self.extra_conv1 = feature_extractor

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.extra_conv1(x)
        x = self.pooling1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pooling2(x)
        return x


model = MODEL()
summary(model, (3, 64, 64))
```
Here I define `feature_extractor` outside the class, and I only get Total params: 19,584. It seems that `feature_extractor` is not part of the graph…
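One way to probe this observation is a small sketch comparing a model that stores a plain function (the Case 2 pattern) with one that stores an `nn.Module` attribute (the Case 1 pattern). The class and function names below (`make_and_apply_conv`, `WithFunction`, `WithModule`, `count_params`) are hypothetical; only standard `torch` / `torch.nn` calls are used.

```python
import torch
import torch.nn as nn


def make_and_apply_conv(x):
    # Same pattern as Case 2: a brand-new Conv2d is built on every call.
    conv = nn.Conv2d(x.size(1), 32, kernel_size=3, padding=1)
    return conv(x)


class WithFunction(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain function is stored as an ordinary attribute,
        # so nn.Module does not track the Conv2d created inside it.
        self.extra_conv1 = make_and_apply_conv


class WithModule(nn.Module):
    def __init__(self):
        super().__init__()
        # An nn.Module attribute is registered as a submodule,
        # so its weight and bias appear in .parameters().
        self.extra_conv1 = nn.Conv2d(32, 32, kernel_size=3, padding=1)


def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


print(count_params(WithFunction()))  # 0
print(count_params(WithModule()))    # 9248

x = torch.randn(1, 32, 8, 8)
m = WithFunction()
y = m.extra_conv1(x)
# The convolution is still recorded by autograd (its output requires grad),
# but its weights are not in m.parameters(), so an optimizer built from
# m.parameters() would never update them.
print(y.requires_grad)
# Two calls on the same input differ because each call creates
# a freshly initialized Conv2d.
print(torch.allclose(m.extra_conv1(x), m.extra_conv1(x)))
```

Under these assumptions, the operation does run inside the autograd graph; what is missing is the registration of its weights as parameters of the model, which is what `summary` counts.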
I know I could've done this in a better way (doing it in the `forward` function), but just out of curiosity: should I define the function inside the class? Why or why not? Thanks in advance!