How to identify the in-place operation in pytorch?

I’ve written some code for text summarization based on the deep reinforced model for abstractive summarization. But when I call backward(), there is some error indicating some in-place operation is executed on some variable, which needs gradient computation.
I’m wondering if there is a way to identify in-place operation in pytorch.

Right below is my code snippet in the forward function. I really appreciate it if someone can help me on this. Thanks so much.

    # PyTorch device location
    device = torch.device('cuda') if cuda and torch.cuda.is_available() else torch.device('cpu')
    
    # For DataParallel with pack_pad and pad_packed
    docs = docs[:,:doc_lens[0]].contiguous()
    bsz, in_seq = docs.size()
    
    # Convert OOV token to UNK
    inputs = docs.clone()
    input_mask = inputs.ge( self.vocab_size )
    inputs.masked_fill_( input_mask, 3 ) # <UNK> token
    
    # Document Word Embedding
    # dembeds: bsz x T_e x emb_size
    dembeds = self.embed( inputs )
    dembeds = F.relu( self.dropout_embed( dembeds ) )
    
    # Pack the embedding sequence
    packed_dembeds = pack_padded_sequence( dembeds, doc_lens, batch_first=True )
    
    # Bidirectional LSTM
    # encoder_hiddens: used to initialize decoder 
    packed_ehiddens, encoder_hiddens = self.encoder( packed_dembeds )
    
    # Unpack ehiddens
    # ehiddens: bsz x T_e x (2*ehid_size)
    ehiddens = pad_packed_sequence( packed_ehiddens, batch_first=True )[0]
    
    # Decoder
    _, target_length = sums.size()
    
    output_mask = sums.ge( self.vocab_size )
    sums.masked_fill_( output_mask, 3 )
    
    # Target Summary Word Embedding
    # sembeds: bsz x T_d x emb_size
    sembeds = self.embed( sums )
    sembeds = F.relu( self.dropout_embed( sembeds ) )
    
    # Decoder start token
    decoder_input_0 = sembeds[:,0:1,:] # SOS token
    
    # Rewrap Encoder Hidden
    decoder_hiddens_0 = [ torch.cat( torch.split( _, 1, dim=0 ), dim=-1 ) for _ in encoder_hiddens ]
    
    # Mask for Encoder Attention
    # en_mask: bsz x 1 x T_e
    en_mask = docs.eq(0).unsqueeze(1)
    
    # batch and token index for Copy Attention
    batch_indices = torch.arange(0, bsz).long()
    batch_indices = batch_indices.expand(in_seq, bsz).transpose(0,1).contiguous().view(-1)
    idx_repeat = torch.arange(0, in_seq).repeat( bsz ).long()
    word_indices = docs.view(-1) # word index in vocab
    
    numbers = docs.view(-1).tolist()
    set_numbers = list(set(numbers)) # all unique numbers
    
    if 0 in set_numbers:
        set_numbers.remove(0)
    
    c = Counter(numbers)
    dup_list = [k for k in set_numbers if (c[k]>1)]
    
    # Cache probs of all timesteps
    p_y = [] 
    
    # Initialize decoder input and hidden
    decoder_input = decoder_input_0
    decoder_hiddens = decoder_hiddens_0
    
    # Decoder unidirectional LSTM
    for t in range( 1, target_length+1 ):
        
        # h_dt: bsz x 1 x dhid_size
        # decoder_hiddens: h_t, c_t
        h_dt, decoder_hiddens = self.decoder( decoder_input, decoder_hiddens )
        
        # Intra-Temporal Attention
        # e_t: bsz x 1 x T_e
        e_t = torch.matmul( h_dt, self.We_attn )
        e_t = torch.bmm( e_t, ehiddens.transpose(1,2) )
        
        if t == 1:
            ep_t = torch.exp( e_t ) # bsz x 1 x T_e
            e = e_t
        else:
            ep_t = torch.exp( e_t ) / torch.sum( torch.exp( e ), dim=1, keepdim=True ) # bsz x 1 x T_e
            e = torch.cat( [e, e_t], dim=1 ) # bsz x t x T_e
            
        # Encoder Attention
        ep_t.masked_fill_( en_mask, 0 )
        en_alpha_t = ep_t / torch.sum( ep_t, dim=2, keepdim=True )
        
        # Encoder Context 
        # en_context_t: bsz x 1 x (2*ehid_size)
        en_context_t = torch.bmm( en_alpha_t, ehiddens )
        
        # Decoder Context Vector
        if t == 1:
            de_context_t = torch.zeros( ( bsz, 1, self.dhid_size ), device=device )
            dhidden = h_dt
        else:
            # Intra-Decoder Attention
            # ed_t: bsz x 1 x t-1
            ed_t = torch.matmul( h_dt, self.Wd_attn )
            ed_t = torch.bmm( ed_t, dhidden.transpose(1,2) )
            
            de_alpha_t = F.softmax( ed_t, dim=2 )
            de_context_t = torch.bmm( de_alpha_t, dhidden )
            
            dhidden = torch.cat( [dhidden, h_dt], dim=1 )
        
        # Merged_context: bsz x 1 x (3*dhid_size)
        merged_context = torch.cat( [ h_dt, en_context_t, de_context_t ], dim=2 )
        
        # p(y_t|u=0)
        # p_yt_u0: bsz x 1 x vocab_size
        p_yt_u0 = F.softmax( self.out( self.dropout( merged_context ) ), dim=2 )
        oovs = torch.zeros( (bsz, 1, self.max_oov), device=device ) + 1.0/self.vocab_size # small epislon to avoid zero prob
        p_yt_u0 = torch.cat( [p_yt_u0, oovs], dim=2 )
        
        # p(u=1)
        # p_u1: bsz x 1 x 1
        p_u1 = F.sigmoid( self.copy( self.dropout( merged_context ) ) )
        
        # Encoder Attention Distribution
        # p(y_t|u=1)
        attn = en_alpha_t.squeeze(1)
        masked_idx_sum = torch.zeros( (bsz, in_seq), device=device)
        dup_attn_sum = torch.zeros( (bsz, in_seq), device=device )

        for dup in dup_list:
            mask = docs.eq( dup ).float()
            masked_idx_sum += mask
            attn_mask = mask * attn
            attn_sum = attn_mask.sum( 1,keepdim=True )
            dup_attn_sum += mask * attn_sum
        attn = attn * (1-masked_idx_sum) + dup_attn_sum
        
        p_yt_u1 = torch.zeros( (bsz, self.vocab_size+self.max_oov), device=device )
        p_yt_u1[ batch_indices, word_indices] += attn[ batch_indices, idx_repeat ]
        p_yt_u1 = p_yt_u1.unsqueeze(1) # bsz x 1 x 1
        
        # p(y_t): bsz x 1 x (vocab_size+max_oov)
        p_yt = p_u1 * p_yt_u1 + ( 1-p_u1 ) * p_yt_u0
        
        # Concatenate for Training
        p_y.append( p_yt )

        # Scheduled Sampling
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
        if use_teacher_forcing:
            if t < target_length: # limit to feasible tokens
                decoder_input = sembeds[:, t:t+1, :]  
        else:
            topv, topi = p_yt.topk( 1, dim=2 )
            next_token = topi.squeeze(1).detach()
            next_mask = next_token.ge( self.vocab_size )
            next_token.masked_fill_( next_mask, 3 ) # OOV -> UNK
            decoder_input = F.relu( self.dropout_embed( self.embed( next_token ) ) )
        
    # log_p_y: bsz x (T_d-1) x vocab_size
    p_y = torch.cat( p_y, dim=1 )
    log_p_y = torch.log( p_y )
1 Like

If you just want to find in-place operations, maybe you can look for operations ends with underscore _. For instance, ‘inputs.masked_fill_( input_mask, 3 )’ in your case. Or more general, to consult with official doc: https://pytorch.org/docs/master/tensors.html to see if there is any returns from the operator.

1 Like

Thanks for the pointer. I actually found out that it’s ep_t.masked_fill_( en_mask, 0 ) that causes this error.
But I’m just wondering why would such operation result in the error?

This link might be useful: Autograd docs