Comparing the outputs is a valid method, but as you’ve already stated you would need to load the state_dict from one model into the other, which might need some manual work, especially if the layer names etc. change.
Tracing the models is also a good idea and you might be able to use torch.fx for it.
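For completeness, here is a rough sketch of such a manual state_dict mapping (not part of my actual example), assuming the SimpleNet1/SimpleNet2 models from the code below have already been created as model1 and model2; the key_map dict is a hypothetical helper mapping the layer names of the first model to the nn.Sequential indices of the second one:
import torch

# hypothetical mapping from SimpleNet1's layer names to SimpleNet2's nn.Sequential indices
key_map = {"fc1": "net.0", "fc2": "net.2", "fc3": "net.4", "out": "net.6"}

sd2 = {}
for name, tensor in model1.state_dict().items():
    layer, param = name.rsplit(".", 1)          # e.g. "fc1", "weight"
    sd2[f"{key_map[layer]}.{param}"] = tensor   # e.g. "net.0.weight"
model2.load_state_dict(sd2)

# with identical parameters both models should now return the same output
x = torch.randn(1, vocab_size)
torch.testing.assert_close(model1(x), model2(x))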
Here is a small code example using your models, which traces both models, checks their module calls, and then compares the attributes of these modules:
import torch
import torch.fx
import torch.nn as nn

class SimpleNet1(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.fc1 = nn.Linear(self.vocab_size, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 3)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(3, 3)
        self.relu3 = nn.ReLU()
        self.out = nn.Linear(3, 2)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        out = self.fc1(X)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.out(out)
        log_probs = self.log_softmax(out)
        return log_probs
class SimpleNet2(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.net = nn.Sequential(
            nn.Linear(self.vocab_size, 4),
            nn.ReLU(),
            nn.Linear(4, 3),
            nn.ReLU(),
            nn.Linear(3, 3),
            nn.ReLU(),
            nn.Linear(3, 2),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, X):
        log_probs = self.net(X)
        return log_probs
vocab_size = 10
model1 = SimpleNet1(vocab_size)
model2 = SimpleNet2(vocab_size)
x = torch.randn(1, 10)
out1 = model1(x)
out2 = model2(x)
traced1 = torch.fx.symbolic_trace(model1)
traced1.graph.print_tabular()
# opcode name target args kwargs
# ----------- ----------- ----------- -------------- --------
# placeholder x X () {}
# call_module fc1 fc1 (x,) {}
# call_module relu1 relu1 (fc1,) {}
# call_module fc2 fc2 (relu1,) {}
# call_module relu2 relu2 (fc2,) {}
# call_module fc3 fc3 (relu2,) {}
# call_module relu3 relu3 (fc3,) {}
# call_module out out (relu3,) {}
# call_module log_softmax log_softmax (out,) {}
# output output output (log_softmax,) {}
traced2 = torch.fx.symbolic_trace(model2)
traced2.graph.print_tabular()
# opcode name target args kwargs
# ----------- ------ -------- -------- --------
# placeholder x X () {}
# call_module net_0 net.0 (x,) {}
# call_module net_1 net.1 (net_0,) {}
# call_module net_2 net.2 (net_1,) {}
# call_module net_3 net.3 (net_2,) {}
# call_module net_4 net.4 (net_3,) {}
# call_module net_5 net.5 (net_4,) {}
# call_module net_6 net.6 (net_5,) {}
# call_module net_7 net.7 (net_6,) {}
# output output output (net_7,) {}
# grab modules
ref_modules1 = dict(traced1.named_modules())
modules1 = []
for n in traced1.graph.nodes:
    #print(n, n.op)
    if n.op == "call_module":
        modules1.append(ref_modules1[n.target])
print(modules1)
# [Linear(in_features=10, out_features=4, bias=True), ReLU(), Linear(in_features=4, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=2, bias=True), LogSoftmax(dim=1)]
ref_modules2 = dict(traced2.named_modules())
modules2 = []
for n in traced2.graph.nodes:
    #print(n, n.op)
    if n.op == "call_module":
        modules2.append(ref_modules2[n.target])
print(modules2)
# [Linear(in_features=10, out_features=4, bias=True), ReLU(), Linear(in_features=4, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=2, bias=True), LogSoftmax(dim=1)]
# compare
assert len(modules1) == len(modules2), "size mismatch"
for m1, m2 in zip(modules1, modules2):
    attributes1 = vars(m1)
    attributes2 = vars(m2)
    for attribute_key in attributes1:
        #print(attribute_key)
        if not attribute_key.startswith("_"):
            a1 = attributes1[attribute_key]
            a2 = attributes2[attribute_key]
            if a1 != a2:
                print(f"mismatch for modules: {m1} for attribute {attribute_key}")
            else:
                print(f"match for modules: {m1} for attribute {attribute_key}")
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute training
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute in_features
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute training
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute in_features
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute training
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute in_features
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute training
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute in_features
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute out_features
# match for modules: LogSoftmax(dim=1) for attribute training
# match for modules: LogSoftmax(dim=1) for attribute dim
Note that I did not handle functional API calls, so the if n.op == "call_module" check would miss these ops in this case.
Functional API calls might also be trickier to check, as you would need to compare the op itself as well as which parameters/buffers were initialized and used, so my example is not 100% reliable but might be a good starting point.
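As a rough illustration (not part of my example above), functional calls show up as call_function nodes in the FX graph, so a comparison would at least have to look at their targets and arguments as well; SimpleNet3 here is just a hypothetical model using F.relu:
import torch.nn.functional as F

class SimpleNet3(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.fc1 = nn.Linear(vocab_size, 4)

    def forward(self, X):
        return F.relu(self.fc1(X))  # F.relu is recorded as a call_function node

traced3 = torch.fx.symbolic_trace(SimpleNet3(vocab_size))
for n in traced3.graph.nodes:
    if n.op == "call_function":
        # n.target is the Python function itself, e.g. torch.nn.functional.relu
        print(n.name, n.target, n.args, n.kwargs)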