Mask Position Embedding without breaking the graph


I have a simple learnable positional embedding and want to mask it based on the pad index. However, masking breaks the graph when I use expand_as, since expand_as returns a view. Should I just add clone() after it?

# Learnable positional-embedding table: one dim-sized vector per position.
# NOTE(review): n_position, dim, and nn are assumed to be defined in the
# enclosing module/class — they are not visible in this snippet; confirm.
pos_embed = nn.Embedding(n_position, dim)

def forward(self, embeds, entities):
    """Add learnable positional embeddings to ``embeds``, zeroed at pad positions.

    Args:
        embeds: ``(batch, seq_len, dim)`` input embeddings.
        entities: ``(batch, seq_len)`` ids; ``0`` marks a padding position.

    Returns:
        ``(batch, seq_len, dim)`` tensor: ``embeds`` plus positional
        embeddings, where the positional contribution at pad positions is 0.
    """
    seq_len = embeds.shape[1]
    # (1, seq_len) position ids on the same device as the inputs.
    # (original referenced undefined `entity_embeds`; use `embeds.device`)
    positions = torch.arange(seq_len, dtype=torch.long, device=embeds.device).unsqueeze(0)

    # expand_as returns a read-only view sharing storage, so writing into it
    # in place corrupts the shared buffer / autograd graph. Instead of
    # clone() + in-place assignment, use the out-of-place masked_fill,
    # which allocates a fresh tensor and keeps gradients flowing correctly.
    pad_mask = (entities == 0).unsqueeze(-1)  # (batch, seq_len, 1), broadcast over dim
    pos_embeds = self.pos_embed(positions).expand_as(embeds)
    pos_embeds = pos_embeds.masked_fill(pad_mask, 0.0)

    # Bug fix: original returned `embeds + pos_embed` (the Embedding module),
    # discarding the masked positional tensor.
    return embeds + pos_embeds