I am using the BERT implementation in PyTorch (the Hugging Face one).
When the forward method of BertEmbeddings runs, this code is executed:
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
I want the gradient of the embeddings, so I added a line at the end of the forward method:
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings.retain_grad()  # this line was added
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
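For context, this is roughly how a run looks; it is only a minimal sketch assuming the Hugging Face transformers package, and the model name, token ids and dummy sum loss are placeholders for illustration, not my real setup:

import torch
from transformers import BertModel

# Sketch only: "bert-base-uncased" and the sum() loss are placeholders.
model = BertModel.from_pretrained("bert-base-uncased")
model.train()

# Hand-written token ids just to have a small input.
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])

outputs = model(input_ids)
sequence_output = outputs[0]   # last hidden states
loss = sequence_output.sum()   # stand-in for my real loss
loss.backward()

# Here I would like to inspect the gradient of the tensor I called
# retain_grad() on inside BertEmbeddings.forward(), but I have no
# reference to that intermediate tensor from outside the module.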
But when I call loss.backward(), the grad of the embeddings is always None.
I also tried loss.backward(keep_graph=True), but I get the error:
TypeError: backward() got an unexpected keyword argument 'keep_graph'
I should also mention that embeddings.requires_grad is True.
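One alternative I considered is registering a hook on the tensor instead of calling retain_grad(), roughly as in the sketch below, but I have not verified that this is the intended approach (the attribute name saved_embedding_grad is just something I made up):

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        # Save the incoming gradient onto the module when backward reaches this tensor.
        # "saved_embedding_grad" is a made-up name for this sketch.
        embeddings.register_hook(lambda grad: setattr(self, "saved_embedding_grad", grad))
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

After loss.backward() I would then read the gradient from model.embeddings.saved_embedding_grad (or model.bert.embeddings.saved_embedding_grad, depending on the wrapper), but I am not sure this is the right way to do it.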
How can I save the gradient?