I am trying to translate an implementation of a Siamese neural network from Keras (tensorflow backend) to PyTorch.
This seemed to have worked quite smoothly but to my surprise the PyTorch model keeps an almost constant loss and doesn’t learn anything. I assume I am missing some PyTorch details.
I am trying to learn PyTorch but I think I am missing a crucial detail here which I don’t spot. The training loop and network principally works just the loss doesn’t reduce it stays more or less constant.
Maybe someone can help me with this?
My Keras network which works fine
def cosine_distance(vects):
x, y = vects
x = K.l2_normalize(x, axis=-1)
y = K.l2_normalize(y, axis=-1)
return -K.mean(x * y, axis=-1, keepdims=True)
def create_base_network(input_shape):
'''Base network to be shared (eq. to feature extraction).
'''
input = Input(name='siamese_in', shape=input_shape)
fwd_lstm = LSTM(512, name='fwdRNN', recurrent_dropout=.2) (input)
bwd_lstm = LSTM(512, name='bwdRNN', recurrent_dropout=.2, go_backwards=True) (input)
x = Concatenate(name='rnn_concat')([fwd_lstm, bwd_lstm])
model = Model(input, x)
return model
input_shape=X_train_A.shape[1:]
base_network = create_base_network(input_shape)
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
vector_1 = base_network(input_a)
vector_2 = base_network(input_b)
x3 = Subtract()([vector_1, vector_2])
x3 = Multiply()([x3, x3])
x1_ = Multiply(name="square_vec1")([vector_1, vector_1])
x2_ = Multiply(name="square_vec2")([vector_2, vector_2])
x4 = Subtract(name="subtract")([x1_, x2_])
x5 = Lambda(cosine_distance, output_shape=eucl_dist_output_shape)([vector_1, vector_2])
conc = Concatenate(axis=-1)([x5, x4, x3])
x = Dense(512, activation="relu")(conc)
x = Dense(256, activation="relu")(x)
x = Dense(128, activation="relu")(x)
out = Dense(1, activation="sigmoid")(x)
model = Model([input_a, input_b], out)
....
model.compile(loss='mse', optimizer='adam')
for i in range(0, 30):
batch_size=512
model.fit([X_train_A, X_train_B], y_train,
batch_size=batch_size,
shuffle=True,
epochs=1)
My translated counter part looks like this:
class SiameseNetwork(nn.Module):
def __init__(self, input_dim, rnn_hidden_dim):
super(SiameseNetwork, self).__init__()
self.rnn_hidden = rnn_hidden_dim
self.rnn = nn.LSTM(input_dim,
hidden_size=self.rnn_hidden,
num_layers=1,
batch_first=True,
bidirectional=True)
self.cos = nn.CosineSimilarity(dim=1)
self.out = nn.Sequential(
nn.Linear(rnn_hidden_dim*4 + 1, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward_input_branch(self, x):
hidden, last = self.rnn(x)
h_fwd = last[0][0]
h_bwd = last[1][0]
return torch.cat((h_fwd, h_bwd), -1)
def transform_branch_output(self, x1, x2):
# create representation as plain
# substraction, substraction of squared representations
# and cosine between representations
# plain substract - squared
x3 = x1 - x2
x3 = x3 * x3
# squared substract
x4 = x1*x1 - x2*x2
# cosine
x5 = self.cos(x1, x2)
x5 = x5.view(x5.size()[0], 1)
# concat all
concat = torch.cat((x5, x4, x3), -1)
return concat
def forward(self, input1, input2):
x1 = self.forward_input_branch(input1)
x2 = self.forward_input_branch(input2)
combo = self.transform_branch_output(x1, x2)
output = self.out(combo)
return output
# put the net on the GPU
net = SiameseNetwork(len(selection), 512).cuda()
loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = optim.Adam(net.parameters())
for epoch in range(0, 20):
for i, data in enumerate(train_dataloader):
A, B, labels = data
A, B, labels = A.cuda(), B.cuda(), labels.cuda()
output = net(A, B)
loss = loss_fn(output.squeeze(), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: {} - Loss {}\n".format(epoch, loss.item()))