Hi everyone! I’ve been working on implementing a seq2seq model with self-attention in pytorch. It’s been giving me some trouble for a few days now. Essentially, the model always learns to repeat the same word over and over. I’ve noticed that the encoder weights for the hidden state are always 0. However, the hidden state bias has non-zero gradients Does anyone have any idea what’s going on here?
class EncoderGRU(nn.Module):
def __init__(self, input_size, hidden_size, bi_dir=True, num_layers=1):
super(EncoderGRU, self).__init__()
self.encoder_gru = nn.GRU(input_size, hidden_size, bidirectional=bi_dir, num_layers=num_layers)
self.encoder_gru.apply(self.init_weights)
self.softmax = nn.Softmax(dim=1)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
self.hidden_size = hidden_size
bi_dir_hidden_size = hidden_size * 2
self.encoder_att_linear = nn.Linear(bi_dir_hidden_size, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.encoder_att_linear.weight)
self.encoder_att_g_gate = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.encoder_att_g_gate.weight)
self.encoder_att_f_gate = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.encoder_att_f_gate.weight)
@staticmethod
def init_weights(w):
if type(w) == nn.GRUCell:
nn.init.xavier_uniform_(w)
def calc_encoder_attn(self, x):
attn = torch.zeros([x.shape[0], x.shape[1], self.hidden_size * 2])
for batch_idx, batch in enumerate(x):
for idx in range(0, batch.shape[0]):
cur_hidden_state = torch.index_select(batch, 0, torch.tensor(idx))
attn_layer = torch.t(self.encoder_att_linear(cur_hidden_state))
attn_layer = self.softmax(torch.matmul(torch.t(attn_layer), torch.t(batch)))
attn_layer = torch.matmul(attn_layer, batch)
attn_layer = torch.cat((attn_layer, cur_hidden_state), dim=1)
g_gate = self.sigmoid(self.encoder_att_g_gate(attn_layer))
f_gate = self.tanh(self.encoder_att_f_gate(attn_layer))
attn_layer = g_gate * f_gate + (1 - g_gate) * cur_hidden_state
attn[batch_idx, idx, :] = attn_layer
return attn
def forward(self, x):
all_hidden, x = self.encoder_gru(x)
x = all_hidden[:, -2:-1, :]
attn = None
return x, attn
class SelfAttnGRU(nn.Module):
def __init__(self, input_size, hidden_size, bi_dir, num_layers):
super(SelfAttnGRU, self).__init__()
self.hidden_size = hidden_size
bi_dir_hidden_size = hidden_size * 2
self.decoder_gru_cell = nn.GRUCell(input_size, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.decoder_gru_cell.weight_hh)
torch.nn.init.xavier_uniform_(self.decoder_gru_cell.weight_ih)
self.decoder_att_linear = nn.Linear(bi_dir_hidden_size, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.decoder_att_linear.weight)
self.decoder_attn_weighted_ctx = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
torch.nn.init.xavier_uniform_(self.decoder_attn_weighted_ctx.weight)
self.decode_prediction = nn.Linear(bi_dir_hidden_size, BERT_VOCAB_SIZE)
torch.nn.init.xavier_uniform_(self.decode_prediction.weight)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=1)
def decode(self, x, attn, y):
preds = torch.zeros([x.shape[0], y.shape[1]])
# TODO: support non bi-dir
softmaxes = torch.zeros([x.shape[0], y.shape[1], BERT_VOCAB_SIZE])
for batch_idx, batch in enumerate(y):
for wrd_idx in range(0, y.shape[1]):
current_word = torch.index_select(batch, 0, torch.tensor(wrd_idx))
if wrd_idx == 0:
x = self.decoder_gru_cell(current_word, x[batch_idx])
else:
x = self.decoder_gru_cell(current_word, x)
softmax = self.decode_prediction(x)
softmaxes[batch_idx][wrd_idx] = softmax
preds[batch_idx][wrd_idx] = torch.argmax(self.softmax(softmax))
# attn_layer = self.decoder_att_linear(x)
# attn_layer = self.softmax(torch.matmul(attn_layer, torch.t(attn[batch_idx])))
# attn_layer = torch.matmul(attn_layer, attn[batch_idx])
# attn_layer = torch.cat((attn_layer, x), dim=1)
# x = self.tanh(self.decoder_attn_weighted_ctx(attn_layer))
return preds, softmaxes
def forward(self, x, attn, y):
labels, softmaxes = self.decode(x, attn, y)
return labels, softmaxes
def train(encoder, decoder, criterion, optimizer_encoder, optimizer_decoder, data, epochs):
# model.cuda()
encoder.train()
decoder.train()
for i in range(0, epochs):
print("------------------")
optimizer_encoder.zero_grad()
optimizer_decoder.zero_grad()
loss = None
for j in range(0, 100):
batch = next(iter(data))
target_labels = torch.tensor(batch['target'])
input_vec = BERT_MODEL(torch.tensor(batch['context']))[0][:, 0:10, :]
output_vec = BERT_MODEL(target_labels)[0]
x, attn = encoder(input_vec)
pred, softmaxes = decoder(x, attn, output_vec)
print("=====")
print(f"ORIGINAL: {BERT_TOKENIZER.decode(target_labels[0])}")
print(f"PRED: {BERT_TOKENIZER.decode(pred[0])}")
print("=====")
target_labels.contiguous().view(-1)
if loss is None:
loss = criterion(softmaxes[0], target_labels[0])
else:
loss += criterion(softmaxes[0], target_labels[0])
loss.backward(retain_graph=True)
optimizer_encoder.step()
optimizer_decoder.step()
# for n, w in encoder.named_parameters():
# if w.grad is None:
# continue
# else:
# print(n)
# print(w.grad)
# if torch.sum(w.grad) == 0:
# print(n)
print(i, loss)
print("------------------")```