Mask Position Embedding without breaking the graph


I have a simple learnable positional embedding and want to mask it based on the pad index. However, masking breaks the graph when I use expand_as, since expand_as returns a view. Should I just add clone() after it?

# Learnable positional-embedding table: one dim-sized vector per position.
# NOTE(review): n_position, dim, and nn are assumed to be defined in the
# enclosing module/class — they are not visible in this snippet; confirm.
pos_embed = nn.Embedding(n_position, dim)

def forward(self, embeds, entities):
    """Add learnable positional embeddings to ``embeds``, zeroed at pad positions.

    Args:
        embeds: ``(batch, seq_len, dim)`` input embeddings.
        entities: ``(batch, seq_len)`` ids; ``0`` marks a padding position.

    Returns:
        ``(batch, seq_len, dim)`` tensor: ``embeds`` plus positional
        embeddings, where the positional contribution at pad positions is 0.
    """
    seq_len = embeds.shape[1]
    # (1, seq_len) position ids on the same device as the inputs.
    # (original referenced undefined `entity_embeds`; use `embeds.device`)
    positions = torch.arange(seq_len, dtype=torch.long, device=embeds.device).unsqueeze(0)

    # expand_as returns a read-only view sharing storage, so writing into it
    # in place corrupts the shared buffer / autograd graph. Instead of
    # clone() + in-place assignment, use the out-of-place masked_fill,
    # which allocates a fresh tensor and keeps gradients flowing correctly.
    pad_mask = (entities == 0).unsqueeze(-1)  # (batch, seq_len, 1), broadcast over dim
    pos_embeds = self.pos_embed(positions).expand_as(embeds)
    pos_embeds = pos_embeds.masked_fill(pad_mask, 0.0)

    # Bug fix: original returned `embeds + pos_embed` (the Embedding module),
    # discarding the masked positional tensor.
    return embeds + pos_embeds