import torch
from torch import nn
class ADD1(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x + 1
class ADD2(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x + 2
class TEST(nn.Module):
def __init__(self):
super().__init__()
self.add = ADD1()
def forward(self, x):
return self.add(x)
x = torch.zeros(1)
test = torch.jit.script(TEST().eval())
print(test.graph)
print(test(x))
# tensor([1.])
new_add = torch.jit.script(ADD2().eval())
# test.add = new_add
# Traceback (most recent call last):
# File "test2.py", line 41, in <module>
# test.add = new_add
# File "/usr/local/lib/python3.8/site-packages/torch/jit/_script.py", line 791, in __setattr__
# self._modules[attr] = value
# File "/usr/local/lib/python3.8/site-packages/torch/jit/_script.py", line 244, in __setitem__
# self._c.setattr(k, v)
# RuntimeError: Expected a value of type '__torch__.ADD1 (of Python compilation unit at: 0x55b702ec2cb0)' for field 'add', but found '__torch__.ADD2 (of Python compilation unit at: 0x55b702ec2cb0)'
test.add._reconstruct(new_add._c)
print(test.graph)
print(test(x))
# tensor([1.])
How to replace test.add
with new_add
in torchscript? I tried test.add = new_add
and test.add._reconstruct(new_add._c)
, but neither of them works.
Thanks