Setting up a Siamese Network model

I’m trying my hands on Siamese Network models. After searching for some examples, this seems to be the common way to set up the model

class SiameseNetwork(nn.Module):

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define layers
        self.layer1 = ...
        self.layer2 = ...
        ...

    def forward_one(self, input):
        X = self.layer1(input)
        X = self.layer2(X)
        ...
        return X

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

Here, the layers of the parallel networks are defined in the model itself, and this seems to work just fine. However, I now wanted to make the model easier to reuse by “extracting” the parallel model as follows:

class CoreModel(nn.Module):

    def __init__(self):
        super(CoreModel, self).__init__()
        # Define layers
        self.layer1 = ...
        self.layer2 = ...
        ...

    def forward(self, input):
        X = self.layer1(input)
        X = self.layer2(X)
        ...
        return X

class SiameseNetwork(nn.Module):

    def __init__(self, core_model):
        super(SiameseNetwork, self).__init__()
        self.core_model = core_model

    def forward(self, input1, input2):
        output1 = self.core_model(input1)
        output2 = self.core_model(input2)
        return output1, output2


core_model = CoreModel()
model = SiameseNetwork(core_model)

The problem is that when I call loss.backward() during training, it says that I cannot call backward() twice. I assume that’s because I wrapped the layers in an nn.Module.

Is there a right way to define the “core model” on its own and then just pass it to the Siamese Network model as a parameter?
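
For reference, a stripped-down version of my training loop looks roughly like this (the actual loss function, data loading, and RNN details are omitted; optimizer, criterion, and loader are just placeholders):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # stand-in for the actual loss

for input1, input2 in loader:  # hypothetical DataLoader
    optimizer.zero_grad()
    output1, output2 = model(input1, input2)
    loss = criterion(output1, output2)
    loss.backward()  # <-- this is where the error is raised
    optimizer.step()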


That’s a weird issue.
I tried to reproduce it using your code snippets, but they seem to work fine:

import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define layers
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def layers(self, input):
        X = self.layer1(input)
        X = self.layer2(X)
        return X
    
    def forward(self, input1, input2):
        output1 = self.layers(input1)
        output2 = self.layers(input2)
        return output1, output2

model = SiameseNetwork()
criterion = nn.MSELoss()

x1 = torch.randn(1, 1)
x2 = torch.randn(1, 1)
out1, out2 = model(x1, x2)
loss = criterion(out1, out2)
loss.backward()

#############################

class CoreModel(nn.Module):
    def __init__(self):
        super(CoreModel, self).__init__()
        # Define layers 
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def forward(self, input):
        X = self.layer1(input)
        X = self.layer2(X)
        return X

class SiameseNetwork(nn.Module):
    def __init__(self, core_model):
        super(SiameseNetwork, self).__init__()
        self.core_model = core_model
    
    def forward(self, input1, input2):
        output1 = self.core_model(input1)
        output2 = self.core_model(input2)
        return output1, output2

core_model = CoreModel()
model = SiameseNetwork(core_model)

criterion = nn.MSELoss()

x1 = torch.randn(1, 1)
x2 = torch.randn(1, 1)
out1, out2 = model(x1, x2)
loss = criterion(out1, out2)
loss.backward()

Could you please check what might be different between your code and my simple reproduction approach?


@ptrblck thanks! Your examples for Variant A and Variant B work for me as well. When I use my “core model” (with an embedding layer, an RNN layer, and linear layers), it still only works with Variant A. All I did was put the code from CoreModel directly into the SiameseNetwork class. I simply cannot see the difference at the moment.

I probably will need to slowly extend the basic working Variant B to see where it begins to break down.

EDIT: Yeah, slowly building up the basic network of Variant B did the trick. I cannot really tell where I made an error in the first place. I assume(!) that I didn’t correctly re-initialize the hidden state of the RNN layer at the right time(s).
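
For anyone running into the same error: the fix most likely boils down to something along these lines (a rough sketch only, not my actual model; the layer sizes and the GRU choice are just placeholders):

class CoreModel(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=8, hidden_dim=16, out_dim=4):
        super(CoreModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, input):
        X = self.embedding(input)
        # Let the RNN create a fresh zero hidden state on every call.
        # If you keep your own hidden state across batches instead, it has to
        # be detach()-ed, otherwise backward() tries to go through the graph
        # of a previous iteration a second time.
        _, h_n = self.rnn(X)
        return self.fc(h_n[-1])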


Sounds like good debugging! :slight_smile:
Please, let me know, if it’s still not working or if you get stuck while debugging.
