I am training my model but noticed that the gradient w.r.t. one weight (bert.embeddings.word_embeddings.weight) is always 0 (not None).
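To be precise about what "always 0" means: the .grad tensor exists, but every entry in it is zero. A minimal sketch of the check I have in mind (adversial_bert is the model instance from the training code below):

import torch

grad = adversial_bert.bert.embeddings.word_embeddings.weight.grad
print(grad is None)                       # False -> the parameter is attached to the graph
print(torch.count_nonzero(grad).item())   # 0 -> every entry of the gradient is zero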
My model:
class AdversialBert(nn.Module):
    def __init__(self, hyperparams: HyperParameters):
        super().__init__()
        self.bert = Bert.from_pretrained(hyperparams.original_checkpoint)
        self.linear = nn.Linear(in_features=768, out_features=2)
        self.hyperparams = hyperparams
        self.ce_loss = nn.CrossEntropyLoss()
        self.pert_strength = self.hyperparams.perturbation_strength

    def n_forward(self, attention_mask: torch.Tensor, target: torch.Tensor, input_ids: torch.Tensor):
        # "normal" forward pass: from input_ids, keeping the embedding so its grad can be reused
        embedding = self.bert.embeddings(input_ids)
        if self.bert.training:
            embedding.retain_grad()
        extended_attention_mask = self.bert.get_extended_attention_mask(attention_mask, attention_mask.shape,
                                                                        device=self.hyperparams.device)
        logits, loss = self.calc_logits_loss(embedding=embedding, target=target,
                                             attention_mask=extended_attention_mask)
        return [logits, loss, embedding]

    def a_forward(self, attention_mask: torch.Tensor, target: torch.Tensor, embedding: torch.Tensor):
        # "adversarial" forward pass: starts from a (perturbed) embedding instead of input_ids
        extended_attention_mask = self.bert.get_extended_attention_mask(attention_mask, attention_mask.shape,
                                                                        device=self.hyperparams.device)
        return self.calc_logits_loss(embedding=embedding, target=target, attention_mask=extended_attention_mask)

    def calc_logits(self, embedding: torch.Tensor, attention_mask: torch.Tensor):
        encoding = self.bert.encoder(hidden_states=embedding, attention_mask=attention_mask).last_hidden_state
        avg_pooled_encoding = torch.mean(input=encoding, dim=1)
        logits = self.linear(avg_pooled_encoding)
        return logits

    def calc_logits_loss(self, embedding: torch.Tensor, target: torch.Tensor, attention_mask: torch.Tensor):
        logits = self.calc_logits(embedding=embedding, attention_mask=attention_mask)
        loss = self.ce_loss(input=logits, target=target.long())
        return [logits, loss]
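For context, as far as I understand the HuggingFace implementation, self.bert.embeddings(input_ids) does the word_embeddings lookup first, then adds position/token-type embeddings, then applies LayerNorm and dropout. So word_embeddings.weight should sit directly upstream of the embedding tensor whose grad I retain. A simplified sketch of that forward path (my reading of transformers' BertEmbeddings, not code from my project):

import torch

def bert_embeddings_forward_sketch(emb_module, input_ids):
    # emb_module is bert.embeddings; this mirrors (simplified) what its forward does
    seq_len = input_ids.size(1)
    words = emb_module.word_embeddings(input_ids)                  # lookup into word_embeddings.weight
    positions = emb_module.position_embeddings(emb_module.position_ids[:, :seq_len])
    token_types = emb_module.token_type_embeddings(torch.zeros_like(input_ids))
    out = emb_module.LayerNorm(words + positions + token_types)    # LayerNorm.bias gets its gradient here
    return emb_module.dropout(out)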
The update step:
def train_one_batch(self, batch, adversial_bert: AdversialBert):
    self.optimizer.zero_grad()
    input_ids = batch[0]
    attention_mask = batch[1]
    target = batch[2]

    print("before n_loss", adversial_bert.bert.embeddings.word_embeddings.weight.grad)
    print("before n_loss", adversial_bert.bert.embeddings.LayerNorm.bias.grad)

    # normal loss
    _, n_loss, n_embedding = adversial_bert.n_forward(attention_mask=attention_mask, target=target, input_ids=input_ids)
    n_loss.backward(retain_graph=True)
    print("after n_loss", adversial_bert.bert.embeddings.word_embeddings.weight.grad)
    print("after n_loss", adversial_bert.bert.embeddings.LayerNorm.bias.grad)

    # adversarial loss
    ## perturb the embedding
    n_embedding_grad = torch.clone(n_embedding.grad)
    shape = n_embedding.shape
    ### std over all embeddings of the batch (across batch and sequence)
    embedding_std = torch.std(n_embedding.reshape([shape[0] * shape[1], shape[2]]), dim=0)
    signed_grad = torch.sign(input=n_embedding_grad)
    a_embedding = n_embedding + self.hyperparams.perturbation_strength * embedding_std * signed_grad
    _, a_loss = adversial_bert.a_forward(attention_mask=attention_mask, embedding=a_embedding, target=target)
    print("after a_loss", adversial_bert.bert.embeddings.word_embeddings.weight.grad)
    print("after a_loss", adversial_bert.bert.embeddings.LayerNorm.bias.grad)

    # embedding difference loss
    e_loss = torch.mean(embedding_std) + torch.abs(torch.mean(n_embedding))

    a_loss = a_loss * self.hyperparams.a_loss_weight
    e_loss = e_loss * self.hyperparams.e_loss_weight
    combined_loss = n_loss + a_loss + e_loss

    self.optimizer.zero_grad()
    combined_loss.backward()
    print("after combined_loss", adversial_bert.bert.embeddings.word_embeddings.weight.grad)
    print("after combined_loss", adversial_bert.bert.embeddings.LayerNorm.bias.grad)
    self.optimizer.step()
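What the perturbation is meant to do is essentially an FGSM-style step in embedding space, a_embedding = n_embedding + eps * std * sign(grad), with the step scaled per hidden dimension by the std of all embeddings in the batch. A standalone toy sketch of just that step (eps stands in for hyperparams.perturbation_strength):

import torch

eps = 0.1                                        # stand-in for hyperparams.perturbation_strength
x = torch.randn(2, 4, 8, requires_grad=True)     # toy (batch, seq, hidden) "embedding"
loss = (x ** 2).sum()
loss.backward()

std = torch.std(x.reshape(-1, x.shape[-1]), dim=0)   # per-dimension std across batch * seq
x_adv = x + eps * std * torch.sign(x.grad)           # sign-gradient perturbation
print(x_adv.shape)                                   # torch.Size([2, 4, 8])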
The prints for adversial_bert.bert.embeddings.LayerNorm.bias.grad are:
>>> 0-tensor
>>> non-0-tensor
>>> non-0-tensor
>>> non-0-tensor
But the prints for adversial_bert.bert.embeddings.word_embeddings.weight.grad are:
>>> 0-tensor
>>> 0-tensor
>>> 0-tensor
>>> 0-tensor
Can someone tell me why adversial_bert.bert.embeddings.word_embeddings.weight.grad is always 0, and how to fix it?
Thank you =)