Encoder hidden state weights have gradient 0 in seq2seq model

Hi everyone! I’ve been implementing a seq2seq model with self-attention in PyTorch, and it has been giving me trouble for a few days now. Essentially, the model always learns to repeat the same word over and over. I’ve also noticed that the gradients of the encoder’s hidden-state weights (`weight_hh`) are always 0, while the hidden-state bias does get non-zero gradients. Does anyone have any idea what’s going on here?
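
For reference, here is roughly how I’m checking the gradients right after `loss.backward()` (a minimal sketch of the check that is commented out in `train()` below; `encoder` is the `EncoderGRU` instance passed to `train`):

```python
# List every encoder parameter whose gradient is identically zero after backward().
for name, param in encoder.named_parameters():
	if param.grad is None:
		continue
	if torch.all(param.grad == 0):
		# In my runs the GRU hidden-state weights (encoder_gru.weight_hh_l0 and
		# encoder_gru.weight_hh_l0_reverse) show up here, but the biases do not.
		print("zero gradient:", name)
```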

```python
import torch
import torch.nn as nn

# BERT_MODEL, BERT_TOKENIZER and BERT_VOCAB_SIZE come from my BERT setup (not shown here).


class EncoderGRU(nn.Module):
	def __init__(self, input_size, hidden_size, bi_dir=True, num_layers=1):
		super(EncoderGRU, self).__init__()
		self.encoder_gru = nn.GRU(input_size, hidden_size, bidirectional=bi_dir, num_layers=num_layers)
		self.encoder_gru.apply(self.init_weights)

		self.softmax = nn.Softmax(dim=1)
		self.tanh = nn.Tanh()
		self.sigmoid = nn.Sigmoid()

		self.hidden_size = hidden_size

		bi_dir_hidden_size = hidden_size * 2

		self.encoder_att_linear = nn.Linear(bi_dir_hidden_size, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.encoder_att_linear.weight)

		self.encoder_att_g_gate = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.encoder_att_g_gate.weight)

		self.encoder_att_f_gate = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.encoder_att_f_gate.weight)

	@staticmethod
	def init_weights(w):
		# Xavier-initialise the GRU's weight matrices (self.encoder_gru is an nn.GRU, not a GRUCell)
		if isinstance(w, nn.GRU):
			for name, param in w.named_parameters():
				if "weight" in name:
					nn.init.xavier_uniform_(param)

	def calc_encoder_attn(self, x):

		attn = torch.zeros([x.shape[0], x.shape[1], self.hidden_size * 2])

		for batch_idx, batch in enumerate(x):
			for idx in range(0, batch.shape[0]):
				cur_hidden_state = torch.index_select(batch, 0, torch.tensor(idx))
				attn_layer = torch.t(self.encoder_att_linear(cur_hidden_state))
				attn_layer = self.softmax(torch.matmul(torch.t(attn_layer), torch.t(batch)))
				attn_layer = torch.matmul(attn_layer, batch)
				attn_layer = torch.cat((attn_layer, cur_hidden_state), dim=1)
				g_gate = self.sigmoid(self.encoder_att_g_gate(attn_layer))
				f_gate = self.tanh(self.encoder_att_f_gate(attn_layer))
				attn_layer = g_gate * f_gate + (1 - g_gate) * cur_hidden_state
				attn[batch_idx, idx, :] = attn_layer
		return attn

	def forward(self, x):
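		# nn.GRU returns (output, h_n); the h_n assigned to x here is overwritten by the slice below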
		all_hidden, x = self.encoder_gru(x)
		x = all_hidden[:, -2:-1, :]
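		# calc_encoder_attn is never called, so no attention is computed and attn stays None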
		attn = None
		return x, attn


class SelfAttnGRU(nn.Module):
	def __init__(self, input_size, hidden_size, bi_dir, num_layers):
		super(SelfAttnGRU, self).__init__()
		self.hidden_size = hidden_size
		bi_dir_hidden_size = hidden_size * 2

		self.decoder_gru_cell = nn.GRUCell(input_size, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.decoder_gru_cell.weight_hh)
		torch.nn.init.xavier_uniform_(self.decoder_gru_cell.weight_ih)

		self.decoder_att_linear = nn.Linear(bi_dir_hidden_size, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.decoder_att_linear.weight)

		self.decoder_attn_weighted_ctx = nn.Linear(bi_dir_hidden_size * 2, bi_dir_hidden_size)
		torch.nn.init.xavier_uniform_(self.decoder_attn_weighted_ctx.weight)

		self.decode_prediction = nn.Linear(bi_dir_hidden_size, BERT_VOCAB_SIZE)
		torch.nn.init.xavier_uniform_(self.decode_prediction.weight)

		self.tanh = nn.Tanh()
		self.sigmoid = nn.Sigmoid()
		self.softmax = nn.Softmax(dim=1)

	def decode(self, x, attn, y):
		preds = torch.zeros([x.shape[0], y.shape[1]])
		# TODO: support non bi-dir
		softmaxes = torch.zeros([x.shape[0], y.shape[1], BERT_VOCAB_SIZE])
		for batch_idx, batch in enumerate(y):
			for wrd_idx in range(0, y.shape[1]):
				current_word = torch.index_select(batch, 0, torch.tensor(wrd_idx))

				if wrd_idx == 0:
					x = self.decoder_gru_cell(current_word, x[batch_idx])
				else:
					x = self.decoder_gru_cell(current_word, x)

				softmax = self.decode_prediction(x)
				softmaxes[batch_idx][wrd_idx] = softmax
				preds[batch_idx][wrd_idx] = torch.argmax(self.softmax(softmax))

				# attn_layer = self.decoder_att_linear(x)
				# attn_layer = self.softmax(torch.matmul(attn_layer, torch.t(attn[batch_idx])))
				# attn_layer = torch.matmul(attn_layer, attn[batch_idx])
				# attn_layer = torch.cat((attn_layer, x), dim=1)
				# x = self.tanh(self.decoder_attn_weighted_ctx(attn_layer))
		return preds, softmaxes

	def forward(self, x, attn, y):
		labels, softmaxes = self.decode(x, attn, y)
		return labels, softmaxes

def train(encoder, decoder, criterion, optimizer_encoder, optimizer_decoder, data, epochs):
	# model.cuda()
	encoder.train()
	decoder.train()
	for i in range(0, epochs):
		print("------------------")
		optimizer_encoder.zero_grad()
		optimizer_decoder.zero_grad()
		loss = None
		for j in range(0, 100):
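			# iter(data) is re-created on every pass through this inner loop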
			batch = next(iter(data))
			target_labels = torch.tensor(batch['target'])

			input_vec = BERT_MODEL(torch.tensor(batch['context']))[0][:, 0:10, :]
			output_vec = BERT_MODEL(target_labels)[0]
			x, attn = encoder(input_vec)
			pred, softmaxes = decoder(x, attn, output_vec)

			print("=====")
			print(f"ORIGINAL: {BERT_TOKENIZER.decode(target_labels[0])}")
			print(f"PRED: {BERT_TOKENIZER.decode(pred[0])}")
			print("=====")

			target_labels.contiguous().view(-1)  # note: the result of this view is discarded
			if loss is None:
				loss = criterion(softmaxes[0], target_labels[0])
			else:
				loss += criterion(softmaxes[0], target_labels[0])

		loss.backward(retain_graph=True)
		optimizer_encoder.step()
		optimizer_decoder.step()
		# for n, w in encoder.named_parameters():
		# 	if w.grad is None:
		# 		continue
		# 	else:
		# 		print(n)
		# 		print(w.grad)
		# 	if torch.sum(w.grad) == 0:
		# 		print(n)

		print(i, loss)
		print("------------------")```