Hello team,
I have a trained model.pt file for a model that uses ReLU activations. I want to switch them to ReLU6 during inference. How can I do that?
Any leads would be really helpful.
Thanks and Regards,
pyberry
If this activation function is defined as a module, you could replace it directly, e.g. via:
model.act = nn.ReLU6()
assuming that all instances of self.act should be changed.
On the other hand, if the functional API was used via e.g. F.relu, then you could write a custom model and override the forward method.
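For the module case, here is a minimal sketch of swapping every nn.ReLU child in an eager (non-scripted) model. The helper name replace_relu_with_relu6 and the toy nn.Sequential model are made up for illustration; it assumes the activations are registered as submodules.

```python
import torch
import torch.nn as nn


def replace_relu_with_relu6(module: nn.Module) -> nn.Module:
    """Recursively replace nn.ReLU children with nn.ReLU6 (hypothetical helper)."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            # Keep the original in-place setting of the replaced layer.
            setattr(module, name, nn.ReLU6(inplace=child.inplace))
        else:
            replace_relu_with_relu6(child)
    return module


# Toy model standing in for the trained network.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.ReLU(),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True)),
)
replace_relu_with_relu6(model)
print(model)
```

ReLU and ReLU6 have no parameters, so the trained weights of the other layers are untouched by the swap.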
Thanks @ptrblck.
The TorchScript model we have doesn’t use model.act, so we may have to use the second option. Can you please point me to an example of writing a custom model that overrides the forward method?
Sure, here is a small example:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.lin1 = nn.Linear(features, features)
        self.lin2 = nn.Linear(features, features)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x


class MyNewModel(MyModel):
    def __init__(self, features):
        super().__init__(features)

    # Same layers as MyModel; only the activation in forward changes.
    def forward(self, x):
        x = self.lin1(x)
        x = F.gelu(x)
        x = self.lin2(x)
        return x


model = MyModel(features=10)
x = torch.randn(1, 10)
out = model(x)

new_model = MyNewModel(features=10)
out = new_model(x)
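Applied to the original ReLU6 question, the same pattern could look like the sketch below. The class name MyReLU6Model is hypothetical, and the freshly initialized MyModel instance only stands in for a trained model; the point is that the subclass has identical parameters, so the trained state_dict can be loaded into it unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    """Stand-in for the original trained model (functional ReLU)."""

    def __init__(self, features):
        super().__init__()
        self.lin1 = nn.Linear(features, features)
        self.lin2 = nn.Linear(features, features)

    def forward(self, x):
        return self.lin2(F.relu(self.lin1(x)))


class MyReLU6Model(MyModel):
    """Hypothetical subclass: same layers, ReLU6 in forward."""

    def forward(self, x):
        return self.lin2(F.relu6(self.lin1(x)))


trained = MyModel(features=10)  # pretend this holds the trained weights
relu6_model = MyReLU6Model(features=10)
relu6_model.load_state_dict(trained.state_dict())  # reuse the weights as-is
relu6_model.eval()

x = torch.randn(1, 10)
with torch.no_grad():
    out = relu6_model(x)
```

Since the model was trained with ReLU, clipping activations at 6 during inference may change its outputs, so it is worth re-checking accuracy after the swap.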
Thanks a lot @ptrblck for your support.
Please see the issue below.
Code:
FILE = "fmnist_scripted.pt"
scripted_model = torch.jit.load(FILE)
for i in scripted_model.named_modules():
    print(i)
myfc3 = nn.Linear(in_features=120, out_features=10)
scripted_fc3 = torch.jit.script(myfc3)
scripted_model.fc3 = scripted_fc3
Output:
('', RecursiveScriptModule(
  original_name=FashionCNN
  (layer1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (layer2): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=ReLU)
    (3): RecursiveScriptModule(original_name=MaxPool2d)
  )
  (fc1): RecursiveScriptModule(original_name=Linear)
  (drop): RecursiveScriptModule(original_name=Dropout2d)
  (fc2): RecursiveScriptModule(original_name=Linear)
  (fc3): RecursiveScriptModule(original_name=Linear)
))
('layer1', RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Conv2d)
  (1): RecursiveScriptModule(original_name=BatchNorm2d)
  (2): RecursiveScriptModule(original_name=ReLU)
  (3): RecursiveScriptModule(original_name=MaxPool2d)
))
('layer1.0', RecursiveScriptModule(original_name=Conv2d))
('layer1.1', RecursiveScriptModule(original_name=BatchNorm2d))
('layer1.2', RecursiveScriptModule(original_name=ReLU))
('layer1.3', RecursiveScriptModule(original_name=MaxPool2d))
('layer2', RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Conv2d)
  (1): RecursiveScriptModule(original_name=BatchNorm2d)
  (2): RecursiveScriptModule(original_name=ReLU)
  (3): RecursiveScriptModule(original_name=MaxPool2d)
))
('layer2.0', RecursiveScriptModule(original_name=Conv2d))
('layer2.1', RecursiveScriptModule(original_name=BatchNorm2d))
('layer2.2', RecursiveScriptModule(original_name=ReLU))
('layer2.3', RecursiveScriptModule(original_name=MaxPool2d))
('fc1', RecursiveScriptModule(original_name=Linear))
('drop', RecursiveScriptModule(original_name=Dropout2d))
('fc2', RecursiveScriptModule(original_name=Linear))
('fc3', RecursiveScriptModule(original_name=Linear))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-31-3c19b7b057b1> in <module>
13 myfc3 = nn.Linear(in_features=120, out_features=10)
14 scripted_fc3 = torch.jit.script(myfc3)
---> 15 scripted_model.fc3 = scripted_fc3
16 #############################################################################################
17 # Carving out & saving layers one-by-one
~/pyvenv/lib/python3.6/site-packages/torch/jit/_script.py in __setattr__(self, attr, value)
672
673 if attr in self._modules:
--> 674 self._modules[attr] = value
675 elif self._c.hasattr(attr):
676 self._c.setattr(attr, value)
~/pyvenv/lib/python3.6/site-packages/torch/jit/_script.py in __setitem__(self, k, v)
221 # otherwise it's illegal and we throw error.
222 if isinstance(v, ScriptModule):
--> 223 self._c.setattr(k, v)
224 self._python_modules[k] = v
225 else:
RuntimeError: Expected a value of type '__torch__.torch.nn.modules.linear.___torch_mangle_4.Linear (of Python compilation unit at: 0x10687bf0)' for field 'fc3', but found '__torch__.torch.nn.modules.linear.___torch_mangle_4.Linear (of Python compilation unit at: 0x3965610)'
I don’t know if manipulating scripted models is possible at all, so we would need to wait for a JIT expert.
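In the meantime, one possible workaround that avoids editing the scripted module in place is sketched below. It assumes the original Python class definition is still available, which may not hold in your case. The idea is to rebuild the model in eager mode with ReLU6, copy the weights over from the scripted model's state_dict, and script it again. TinyNet is a made-up stand-in for the real architecture, and an in-memory buffer replaces the .pt file so the snippet is self-contained.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the original model class."""

    def __init__(self, act: nn.Module):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = act
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


# Simulate the existing scripted file (ReLU version) with an in-memory buffer.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(TinyNet(nn.ReLU())), buffer)
buffer.seek(0)
scripted_model = torch.jit.load(buffer)

# Rebuild in eager mode with ReLU6 and copy the trained weights over.
eager_model = TinyNet(nn.ReLU6())
eager_model.load_state_dict(scripted_model.state_dict())
eager_model.eval()

# Script the modified model again for deployment.
rescripted = torch.jit.script(eager_model)
```

This only works if the parameter and buffer names of the rebuilt model match those stored in the scripted one; load_state_dict will raise an error on a mismatch.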