I was able to resolve my issue thanks to Omkar’s suggestion to look at https://github.com/pytorch/pytorch/issues/8637.
Here is what I changed: instead of binding `forward` to the `only_forward` method, I now call `only_forward` directly inside the `forward` method:
# Attach an internal classifier head only when add_output is set.
# self.no_output records which branch forward() takes at call time.
if add_output:
    self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes)
    self.no_output = False
else:
    self.output = None
    # Previously: self.forward = self.only_forward
    # Rebinding forward as an instance attribute caused the problem discussed in
    # pytorch/pytorch#8637, so forward() now calls only_forward() itself instead.
    self.no_output = True
def forward(self, x):
    """Run the block, optionally also producing the internal classifier output.

    If ``self.no_output`` is truthy, the result of ``self.only_forward(x)`` is
    returned unchanged. Otherwise this returns the 3-tuple
    ``(self.layers[2](h), 1, self.output(h))``, where ``h`` is the sum of the
    conv-layer branch ``self.layers[0](x)`` and the shortcut branch
    ``self.layers[1](x)``. The middle element presumably flags that an exit
    output is present (``only_forward`` uses 0) -- confirm with callers.

    NOTE(review): here ``self.output`` receives ``h`` *before* ``layers[2]``
    (the activation), whereas ``only_output`` applies ``layers[2]`` first.
    Confirm the difference is intended; it only disappears if ``layers[2]``
    modifies its input in place.
    """
    if self.no_output:
        return self.only_forward(x)

    # Left operand (conv layers) is evaluated before the shortcut, as before.
    merged = self.layers[0](x) + self.layers[1](x)
    return self.layers[2](merged), 1, self.output(merged)
def only_output(self, x):
    """Return only the internal classifier's output for input ``x``.

    The conv-layer branch ``self.layers[0](x)`` and the shortcut branch
    ``self.layers[1](x)`` are summed, passed through ``self.layers[2]``
    (the activation), and the result is fed to ``self.output``.
    """
    conv_out, shortcut_out = self.layers[0](x), self.layers[1](x)
    activated = self.layers[2](conv_out + shortcut_out)
    return self.output(activated)
def only_forward(self, x):
    """Run the block without the internal classifier.

    Returns the 3-tuple ``(self.layers[2](h), 0, None)``, where ``h`` is the
    sum of the conv-layer branch ``self.layers[0](x)`` and the shortcut
    branch ``self.layers[1](x)``.
    """
    merged = self.layers[0](x) + self.layers[1](x)
    activated = self.layers[2](merged)
    return activated, 0, None