Changing the activation function during inference

Hello team,
I have a trained model.pt file for a model that uses ReLU activations, and I want to switch them to ReLU6 during inference. How can I do that?

Any leads would be really helpful.

Thanks and Regards,
pyberry

If this activation function is defined as a module, you could replace it directly, e.g. via:

model.act = nn.ReLU6()

assuming that all instances of self.act should be changed.
On the other hand, if the functional API was used via e.g. F.relu, then you could write a custom model and override the forward method.
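
For the module case, a minimal sketch of swapping every nn.ReLU in an eager (non-scripted) model could look like the following. The helper name replace_relu_with_relu6 and the toy model are made up for illustration and are not taken from the model in question:

```python
import torch
import torch.nn as nn


def replace_relu_with_relu6(module: nn.Module) -> nn.Module:
    """Recursively swap each nn.ReLU child for nn.ReLU6 (hypothetical helper)."""
    # list(...) so the children are not mutated while being iterated
    for name, child in list(module.named_children()):
        if isinstance(child, nn.ReLU):
            setattr(module, name, nn.ReLU6(inplace=child.inplace))
        else:
            replace_relu_with_relu6(child)
    return module


# Toy model used only for illustration
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.ReLU()),
)
replace_relu_with_relu6(model)

# Scale the input so some activations would exceed 6 without the clamp
out = model(torch.randn(1, 1, 8, 8) * 10)
```

This only finds activations that were registered as modules; calls to F.relu inside forward are not touched.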

Thanks @ptrblck.
The TorchScript model we have doesn't use model.act, so we may have to use the second option. Could you please point me to an example of writing a custom model that overrides the forward method?

Sure, here is a small example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.lin1 = nn.Linear(features, features)
        self.lin2 = nn.Linear(features, features)
        
    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x

class MyNewModel(MyModel):
    def __init__(self, features):
        super().__init__(features)
        
    def forward(self, x):
        x = self.lin1(x)
        x = F.gelu(x)
        x = self.lin2(x)
        return x


model = MyModel(features=10)
x = torch.randn(1, 10)
out = model(x)

new_model = MyNewModel(features=10)
out = new_model(x)  # call the subclass so the overridden forward is used
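
Applied to the original question, the same pattern with F.relu6 might look like the sketch below. MyReLU6Model is a hypothetical name, and the freshly initialized MyModel only stands in for the trained model. load_state_dict copies its parameters into the subclass, so only the activation changes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.lin1 = nn.Linear(features, features)
        self.lin2 = nn.Linear(features, features)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x


class MyReLU6Model(MyModel):  # hypothetical name
    # __init__ is inherited; only forward changes
    def forward(self, x):
        x = self.lin1(x)
        x = F.relu6(x)  # ReLU6 instead of ReLU
        x = self.lin2(x)
        return x


model = MyModel(features=10)  # stand-in for the trained model

new_model = MyReLU6Model(features=10)
new_model.load_state_dict(model.state_dict())  # reuse the trained parameters
new_model.eval()

x = torch.randn(1, 10)
with torch.no_grad():
    out = new_model(x)
```

This assumes the Python class definition of the original model is available, which is not the case if only a scripted file exists.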

Thanks a lot @ptrblck for your support.
Please see the issue below.

Code:

import torch
import torch.nn as nn

FILE = "fmnist_scripted.pt"
scripted_model = torch.jit.load(FILE)

# Print every (name, submodule) pair of the scripted model
for i in scripted_model.named_modules():
    print(i)

myfc3 = nn.Linear(in_features=120, out_features=10)
scripted_fc3 = torch.jit.script(myfc3)
scripted_model.fc3 = scripted_fc3

Output:

('', RecursiveScriptModule(
  original_name=FashionCNN
  (layer1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (layer2): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (fc1): RecursiveScriptModule(original_name=Linear)
  (drop): RecursiveScriptModule(original_name=Dropout2d)
  (fc2): RecursiveScriptModule(original_name=Linear)
  (fc3): RecursiveScriptModule(original_name=Linear)
))
('layer1', RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Conv2d)
  (1): RecursiveScriptModule(original_name=BatchNorm2d)
  (2): RecursiveScriptModule(original_name=ReLU)
  (3): RecursiveScriptModule(original_name=MaxPool2d)
))
('layer1.0', RecursiveScriptModule(original_name=Conv2d))
('layer1.1', RecursiveScriptModule(original_name=BatchNorm2d))
('layer1.2', RecursiveScriptModule(original_name=ReLU))
('layer1.3', RecursiveScriptModule(original_name=MaxPool2d))
('layer2', RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Conv2d)
  (1): RecursiveScriptModule(original_name=BatchNorm2d)
  (2): RecursiveScriptModule(original_name=ReLU)
  (3): RecursiveScriptModule(original_name=MaxPool2d)
))
('layer2.0', RecursiveScriptModule(original_name=Conv2d))
('layer2.1', RecursiveScriptModule(original_name=BatchNorm2d))
('layer2.2', RecursiveScriptModule(original_name=ReLU))
('layer2.3', RecursiveScriptModule(original_name=MaxPool2d))
('fc1', RecursiveScriptModule(original_name=Linear))
('drop', RecursiveScriptModule(original_name=Dropout2d))
('fc2', RecursiveScriptModule(original_name=Linear))
('fc3', RecursiveScriptModule(original_name=Linear))

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-3c19b7b057b1> in <module>
     13 myfc3 = nn.Linear(in_features=120, out_features=10)
     14 scripted_fc3 = torch.jit.script(myfc3)
---> 15 scripted_model.fc3 = scripted_fc3
     16 #############################################################################################
     17 # Carving out & saving layers one-by-one

~/pyvenv/lib/python3.6/site-packages/torch/jit/_script.py in __setattr__(self, attr, value)
    672 
    673             if attr in self._modules:
--> 674                 self._modules[attr] = value
    675             elif self._c.hasattr(attr):
    676                 self._c.setattr(attr, value)

~/pyvenv/lib/python3.6/site-packages/torch/jit/_script.py in __setitem__(self, k, v)
    221         # otherwise it's illegal and we throw error.
    222         if isinstance(v, ScriptModule):
--> 223             self._c.setattr(k, v)
    224             self._python_modules[k] = v
    225         else:

RuntimeError: Expected a value of type '__torch__.torch.nn.modules.linear.___torch_mangle_4.Linear (of Python compilation unit at: 0x10687bf0)' for field 'fc3', but found '__torch__.torch.nn.modules.linear.___torch_mangle_4.Linear (of Python compilation unit at: 0x3965610)'

I don’t know if manipulating scripted models is possible at all, so we would need to wait for a JIT expert.
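
A possible workaround, sketched under assumptions: if the goal is only to give fc3 new weights of the same shape, copying the values into the existing scripted submodule avoids replacing the module, and so avoids the type check that raises the error above. The named nn.Sequential below is a hypothetical stand-in for the loaded file, since fmnist_scripted.pt is not available here:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the scripted model loaded from disk
eager = nn.Sequential(OrderedDict([
    ("fc2", nn.Linear(84, 120)),
    ("act", nn.ReLU()),
    ("fc3", nn.Linear(120, 10)),
]))
scripted_model = torch.jit.script(eager)

# New layer with the same shape as the existing fc3
myfc3 = nn.Linear(in_features=120, out_features=10)

# Copy the values in place instead of assigning a new submodule
with torch.no_grad():
    scripted_model.fc3.weight.copy_(myfc3.weight)
    scripted_model.fc3.bias.copy_(myfc3.bias)

x = torch.randn(1, 84)
with torch.no_grad():
    out = scripted_model(x)
```

This leaves each module's type unchanged, so by itself it would not turn the scripted ReLU layers into ReLU6.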
