Can I (reliably) check if 2 classes implement the same model?

Is there a way to check / decide if 2 network models are identical despite a different implementation? For example, let’s assume I have the following two classes, SimpleNet1 and SimpleNet2.

import torch.nn as nn


class SimpleNet1(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.fc1 = nn.Linear(self.vocab_size, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 3)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(3, 3)
        self.relu3 = nn.ReLU()
        self.out = nn.Linear(3, 2)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        out = self.fc1(X)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.out(out)
        log_probs = self.log_softmax(out)
        return log_probs


class SimpleNet2(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

        self.net = nn.Sequential(
            nn.Linear(self.vocab_size, 4),
            nn.ReLU(),
            nn.Linear(4, 3),
            nn.ReLU(),
            nn.Linear(3, 3),
            nn.ReLU(),
            nn.Linear(3, 2),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, X):
        log_probs = self.net(X)
        return log_probs

Both models are arguably identical, as they use the exact same layers in the same order, etc. However, the syntax is different: SimpleNet2 uses nn.Sequential, which yields simpler code. And of course, for a more complex model, the number of ways to implement it only grows.

Right now, I was wondering about 2 alternatives:

  • Compare the output for the same input. If two models are indeed the same, they should yield the same output for the same input. However, this would require that all corresponding layers are initialized the same way, and it doesn’t seem easy to reliably identify the corresponding layers. Things like nn.Dropout also seem to cause problems.
  • Use TorchScript to compare models. From what I understand – but that’s not too much, unfortunately – it is possible to convert a model to TorchScript via tracing and/or scripting. It looks to me that this would abstract away syntactic sugar like nn.Sequential, thus (hopefully) making it easier to check if two models are identical.

Any comments on these two alternatives or any other ideas are very much appreciated.

Comparing the outputs is a valid method, but as you’ve already stated, you would need to load the state_dict from one model to the other, which might need some manual work, especially if the layer names etc. change.
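For your simple example a positional copy of the parameters might already be enough to compare the outputs. A rough sketch (assuming both models register their parameters in the same order and with matching shapes, and reusing the SimpleNet classes from above):

import torch

model1 = SimpleNet1(10)
model2 = SimpleNet2(10)

# copy the parameters by position, since the names differ (fc1.weight vs. net.0.weight)
with torch.no_grad():
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        p2.copy_(p1)

# call eval() so that layers such as nn.Dropout don't randomize the comparison
model1.eval()
model2.eval()

x = torch.randn(4, 10)
print(torch.allclose(model1(x), model2(x)))  # True here, since the layers really do correspond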

Tracing the model is also a good idea and you might be able to use torch.fx for it.
Here is a small code example using your models, which traces both models, checks their module calls, and then compares the attributes of these modules:

import torch
import torch.fx
import torch.nn as nn


class SimpleNet1(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.fc1 = nn.Linear(self.vocab_size, 4)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4, 3)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(3, 3)
        self.relu3 = nn.ReLU()
        self.out = nn.Linear(3, 2)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        out = self.fc1(X)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.out(out)
        log_probs = self.log_softmax(out)
        return log_probs


class SimpleNet2(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.net = nn.Sequential(
            nn.Linear(self.vocab_size, 4),
            nn.ReLU(),
            nn.Linear(4, 3),
            nn.ReLU(),
            nn.Linear(3, 3),
            nn.ReLU(),
            nn.Linear(3, 2),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, X):
        log_probs = self.net(X)
        return log_probs

vocab_size = 10
model1 = SimpleNet1(vocab_size)
model2 = SimpleNet2(vocab_size)

x = torch.randn(1, 10)
out1 = model1(x)
out2 = model2(x)

traced1 = torch.fx.symbolic_trace(model1)
traced1.graph.print_tabular()
# opcode       name         target       args            kwargs
# -----------  -----------  -----------  --------------  --------
# placeholder  x            X            ()              {}
# call_module  fc1          fc1          (x,)            {}
# call_module  relu1        relu1        (fc1,)          {}
# call_module  fc2          fc2          (relu1,)        {}
# call_module  relu2        relu2        (fc2,)          {}
# call_module  fc3          fc3          (relu2,)        {}
# call_module  relu3        relu3        (fc3,)          {}
# call_module  out          out          (relu3,)        {}
# call_module  log_softmax  log_softmax  (out,)          {}
# output       output       output       (log_softmax,)  {}

traced2 = torch.fx.symbolic_trace(model2)
traced2.graph.print_tabular()
# opcode       name    target    args      kwargs
# -----------  ------  --------  --------  --------
# placeholder  x       X         ()        {}
# call_module  net_0   net.0     (x,)      {}
# call_module  net_1   net.1     (net_0,)  {}
# call_module  net_2   net.2     (net_1,)  {}
# call_module  net_3   net.3     (net_2,)  {}
# call_module  net_4   net.4     (net_3,)  {}
# call_module  net_5   net.5     (net_4,)  {}
# call_module  net_6   net.6     (net_5,)  {}
# call_module  net_7   net.7     (net_6,)  {}
# output       output  output    (net_7,)  {}

# grab modules
ref_modules1 = dict(traced1.named_modules())
modules1 = []
for n in traced1.graph.nodes:
    #print(n, n.op)
    if n.op == "call_module":
        modules1.append(ref_modules1[n.target])
print(modules1)
# [Linear(in_features=10, out_features=4, bias=True), ReLU(), Linear(in_features=4, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=2, bias=True), LogSoftmax(dim=1)]

ref_modules2 = dict(traced2.named_modules())
modules2 = []
for n in traced2.graph.nodes:
    #print(n, n.op)
    if n.op == "call_module":
        modules2.append(ref_modules2[n.target])
print(modules2)
# [Linear(in_features=10, out_features=4, bias=True), ReLU(), Linear(in_features=4, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=3, bias=True), ReLU(), Linear(in_features=3, out_features=2, bias=True), LogSoftmax(dim=1)]

# compare
assert len(modules1)==len(modules2), "size mismatch"
for m1, m2 in zip(modules1, modules2):
    attributes1 = vars(m1)
    attributes2 = vars(m2)
    for attribute_key in attributes1:
        #print(attribute_key)
        if not attribute_key.startswith("_"):
            a1 = attributes1[attribute_key]
            a2 = attributes2[attribute_key]
            if a1 != a2:
                print(f"mismatch for modules: {m1} for attribute {attribute_key}")
            else:
                print(f"match for modules: {m1} for attribute {attribute_key}")
        
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute training
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute in_features
# match for modules: Linear(in_features=10, out_features=4, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute training
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute in_features
# match for modules: Linear(in_features=4, out_features=3, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute training
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute in_features
# match for modules: Linear(in_features=3, out_features=3, bias=True) for attribute out_features
# match for modules: ReLU() for attribute training
# match for modules: ReLU() for attribute inplace
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute training
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute in_features
# match for modules: Linear(in_features=3, out_features=2, bias=True) for attribute out_features
# match for modules: LogSoftmax(dim=1) for attribute training
# match for modules: LogSoftmax(dim=1) for attribute dim

Note that I did not handle functional API calls, so the if n.op == "call_module" check would miss them.
Functional API calls might be trickier to check, as you would need to compare the op as well as which parameters/buffers were initialized and used, so my example is not 100% reliable, but it might be a good starting point.
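If you want to at least detect functional or tensor-method calls, one rough extension (just a sketch, and it still ignores the arguments of those calls) would be to compare the sequence of ops and targets for all relevant nodes instead of only the call_module ones:

def graph_signature(traced):
    # collect a (kind, name) tuple for every module / function / method call in the graph
    modules = dict(traced.named_modules())
    sig = []
    for n in traced.graph.nodes:
        if n.op == "call_module":
            sig.append(("module", type(modules[n.target]).__name__))
        elif n.op in ("call_function", "call_method"):
            sig.append((n.op, str(n.target)))
    return sig

print(graph_signature(traced1) == graph_signature(traced2))
# True for the SimpleNet models above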


@ptrblck thanks so much for the example code! I’ve just started looking into this TorchScript idea.

Yes, unfortunately, I cannot guarantee that the layers have the same names across different classes. I’m basically trying to figure out if I can automatically check my students’ submissions against my reference solution :).


@ptrblck I’ve modified your example to include a more complex network containing an nn.LSTM. The main difference here is that in one model I use batch_first=True and in the other model batch_first=False for the nn.LSTM. This mismatch is correctly identified in the final output. However, since I also use transpose() in the second model and change out[:, -1, :] (for batch_first=True) to out[-1] (for batch_first=False), both models are in fact identical.
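As a quick sanity check that batch_first only changes the input/output layout and not the computation itself, one can copy the weights from one LSTM to the other (batch_first is not stored in the state_dict) and compare; a small sketch with made-up sizes:

import torch
import torch.nn as nn

lstm_bf = nn.LSTM(8, 16, batch_first=True)
lstm_sf = nn.LSTM(8, 16, batch_first=False)
lstm_sf.load_state_dict(lstm_bf.state_dict())  # batch_first is not part of the state_dict

x = torch.randn(5, 20, 8)               # (batch, seq, feature)
out_bf, _ = lstm_bf(x)                  # (batch, seq, hidden)
out_sf, _ = lstm_sf(x.transpose(0, 1))  # (seq, batch, hidden)

print(torch.allclose(out_bf[:, -1, :], out_sf[-1]))  # True (up to floating point noise)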

In the loop

for n in traced1.graph.nodes:
    print(n, n.op, n.args)
    ...

I can see the use of transpose() and the different slicing of out, but now I can’t really see how I can use the trace (alone) to check for identity. My current thought is to use the trace to identify matching modules so I can copy the state_dict of each module from the 1st network to its counterpart in the 2nd network (see the sketch at the end of this post). There would still be the problem that the two nn.LSTM layers mismatch because of batch_first; maybe I can consider only a subset of attributes and ignore things like batch_first. But it feels like this is getting more and more shaky :slight_smile:

Is there something obvious I’m missing here?

class SimpleLSTM1(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, 64)
        self.relu1 = nn.ReLU()
        self.out = nn.Linear(64, output_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        out = self.embed(X)
        out, (h, c) = self.lstm(out)
        out = self.fc1(out[:, -1, :])
        out = self.relu1(out)
        out = self.out(out)
        log_probs = self.log_softmax(out)
        return log_probs
        
        
class SimpleLSTM2(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, output_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=False)
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_size),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, X):
        out = self.embed(X)
        out = out.transpose(0, 1)
        out, (h, c) = self.lstm(out)
        log_probs = self.net(out[-1])
        return log_probs
        
        
vocab_size, embed_size, hidden_size, output_size = 100, 300, 512, 3

model1 = SimpleLSTM1(vocab_size, embed_size, hidden_size, output_size)
model2 = SimpleLSTM2(vocab_size, embed_size, hidden_size, output_size)

batch_size, seq_len = 5, 20
x = torch.randint(vocab_size, (batch_size, seq_len))

out1 = model1(x)
out2 = model2(x)        
        
        
        
traced1 = torch.fx.symbolic_trace(model1)
traced1.graph.print_tabular()
# opcode         name         target                       args                                                               kwargs
# -------------  -----------  ---------------------------  -----------------------------------------------------------------  --------
# placeholder    x            X                            ()                                                                 {}
# call_module    embed        embed                        (x,)                                                               {}
# call_module    lstm         lstm                         (embed,)                                                           {}
# call_function  getitem      <built-in function getitem>  (lstm, 0)                                                          {}
# call_function  getitem_1    <built-in function getitem>  (lstm, 1)                                                          {}
# call_function  getitem_2    <built-in function getitem>  (getitem_1, 0)                                                     {}
# call_function  getitem_3    <built-in function getitem>  (getitem_1, 1)                                                     {}
# call_function  getitem_4    <built-in function getitem>  (getitem, (slice(None, None, None), -1, slice(None, None, None)))  {}
# call_module    fc1          fc1                          (getitem_4,)                                                       {}
# call_module    relu1        relu1                        (fc1,)                                                             {}
# call_module    out          out                          (relu1,)                                                           {}
# call_module    log_softmax  log_softmax                  (out,)                                                             {}
# output         output       output                       (log_softmax,)                                                     {}        
        
traced2 = torch.fx.symbolic_trace(model2)
traced2.graph.print_tabular()
        
        
# rebuild the module lists from the new traces (same approach as above)
ref_modules1 = dict(traced1.named_modules())
modules1 = [ref_modules1[n.target] for n in traced1.graph.nodes if n.op == "call_module"]

ref_modules2 = dict(traced2.named_modules())
modules2 = [ref_modules2[n.target] for n in traced2.graph.nodes if n.op == "call_module"]

assert len(modules1) == len(modules2), "size mismatch"

for m1, m2 in zip(modules1, modules2):
    attributes1 = vars(m1)
    attributes2 = vars(m2)
    for attribute_key in attributes1:
        if not attribute_key.startswith("_"):
            a1 = attributes1[attribute_key]
            a2 = attributes2[attribute_key]
            if a1 != a2:
                print(f"mismatch for modules: {m1} for attribute {attribute_key}")
            else:
                print(f"match for modules: {m1} for attribute {attribute_key}")
                
  
# ...              
# match for modules: LSTM(300, 512, batch_first=True) for attribute bias
# mismatch for modules: LSTM(300, 512, batch_first=True) for attribute batch_first
# match for modules: LSTM(300, 512, batch_first=True) for attribute dropout                   
# ...
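
For completeness, here is the rough continuation I had in mind (only a sketch; it relies on the call_module nodes appearing in the same order in both traces): copy each matched module’s state_dict from model1 into its counterpart in model2 and then compare the outputs. Since batch_first is not part of the LSTM’s state_dict, the weights can still be copied despite that attribute mismatch.

# copy the state_dict of each matched module from model1 into its counterpart in model2
for m1, m2 in zip(modules1, modules2):
    m2.load_state_dict(m1.state_dict())

model1.eval()
model2.eval()

out1 = model1(x)
out2 = model2(x)
print(torch.allclose(out1, out2))
# should print True, since the two models are functionally identical despite batch_first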