I am new to programming and can't figure something out. Thanks in advance for your help.
So the question is about what happens when we make a model class like this:
import torch.nn as nn

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        # fext and classifier are nn.Modules defined elsewhere
        self.featureextract = fext
        self.classifier = classifier

    def forward(self, input):
        x = self.classifier(input)
        return x

    def feature_extract(self, input):
        x = self.featureextract(input)
        return x
In a case like this, where I have two functions, how does backpropagation happen? Will the parameters of the model contain only the forward part, or both? And if it is only the forward part, how can I build a model where the features are extracted separately and the classifier is applied separately, but backpropagation still happens through the complete model?
The parameters of the model are not tied to the definition of forward or any other method. They are registered inside __init__, when you assign submodules (or nn.Parameters) as attributes.
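A minimal sketch of this, assuming nn.Linear layers as hypothetical stand-ins for fext and classifier: even though forward only uses the classifier, both submodules show up in the model's parameters.

```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-ins for the question's fext and classifier
        self.featureextract = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        # forward only uses the classifier...
        return self.classifier(x)

model = Model()
# ...but both submodules were registered in __init__
names = [name for name, _ in model.named_parameters()]
print(names)
# ['featureextract.weight', 'featureextract.bias', 'classifier.weight', 'classifier.bias']
```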
The computation graph is created dynamically from the operations that are actually executed in the forward pass.
I.e. you don't even need an nn.Module to wrap your parameters, but could also work with parameters directly:
import torch
import torch.nn as nn

x = torch.randn(1, 1)
w = nn.Parameter(torch.randn(1, 1))
y = x * w
y.backward()  # y has a single element, so no gradient argument is needed
print(w.grad)
This is of course not the recommended way to write complicated models, but it might give you some insight into how the computation graph and Autograd work.
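To connect this to the original question, here is a small sketch (the layer sizes and nn.Linear stand-ins are assumptions) showing that backward only reaches parameters that took part in the executed computation. A registered parameter that was never used simply keeps grad set to None.

```python
import torch
import torch.nn as nn

# Hypothetical submodules standing in for fext and classifier
featureextract = nn.Linear(8, 4)
classifier = nn.Linear(4, 2)

x = torch.randn(3, 4)
loss = classifier(x).sum()  # only the classifier is part of this graph
loss.backward()

print(classifier.weight.grad is None)      # False: used in the forward pass
print(featureextract.weight.grad is None)  # True: never used, so no gradient
```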
The disadvantage of your current approach is that you should usually call the model directly:
output = model(input)
instead of calling the forward method or other methods yourself:
output1 = model.forward(input)
output2 = model.feature_extract(input)
The former approach goes through the internal __call__ method, which, for example, runs any registered hooks and then calls forward.
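A minimal sketch of the difference, using a single nn.Linear layer as an assumed example: a forward hook registered with register_forward_hook runs when the module is called, but is skipped when forward is called directly.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []
# The hook is executed by __call__, not by forward() itself
layer.register_forward_hook(lambda module, inp, out: calls.append("hook"))

x = torch.randn(1, 4)
layer(x)          # goes through __call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped
print(calls)      # ['hook']
```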
A possible workaround would be to define forward with a flag that decides which path to choose:
import torch.nn as nn

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        # fext and classifier are nn.Modules defined elsewhere
        self.featureextract = fext
        self.classifier = classifier

    def forward(self, input, path='class'):
        # the flag selects the path, so model(input, path=...) still goes through __call__
        if path == 'class':
            x = self.classifier(input)
        else:
            x = self.featureextract(input)
        return x

    def feature_extract(self, input):
        x = self.featureextract(input)
        return x
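To address the last part of the question, here is a sketch of using the flag-based model in two separate steps and then backpropagating through both. The nn.Linear layers, input shape, and targets are assumptions for illustration. Because both calls run in the same graph, a single loss.backward() populates gradients in the feature extractor and the classifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layers standing in for fext and classifier
        self.featureextract = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, input, path='class'):
        if path == 'class':
            return self.classifier(input)
        return self.featureextract(input)

model = Model()
x = torch.randn(3, 8)
features = model(x, path='feature')     # feature extraction step
logits = model(features, path='class')  # classification step
loss = F.cross_entropy(logits, torch.tensor([0, 1, 0]))
loss.backward()

# Both parts received gradients, since both ran in the same graph
print(model.featureextract.weight.grad is not None)  # True
print(model.classifier.weight.grad is not None)      # True
```

Keeping both steps as calls to model(...) also means any hooks on the model run for each step.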
Thanks a lot, sir, I got it now.