I’ve written some text-summarization code based on “A Deep Reinforced Model for Abstractive Summarization” (Paulus et al., 2017). When I call backward(), I get a RuntimeError saying that one of the variables needed for gradient computation has been modified by an in-place operation.
I’m wondering if there is a way to identify which in-place operation is responsible in PyTorch.
Below is the relevant snippet from my forward function. I’d really appreciate any help with this. Thanks!
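For illustration, here is a minimal sketch (entirely separate from my model; all names are made up) that reproduces the same class of error, where an in-place write clobbers a tensor autograd saved for the backward pass:

import torch

x = torch.ones(3, requires_grad=True)
y = torch.sigmoid(x)                      # sigmoid saves its output for backward
mask = torch.tensor([True, False, False])
y.masked_fill_(mask, 0.0)                 # in-place write invalidates the saved output
y.sum().backward()                        # RuntimeError: one of the variables needed for
                                          # gradient computation has been modified by an
                                          # inplace operation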
# Imports assumed by this snippet
import random
from collections import Counter

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# PyTorch device location
device = torch.device('cuda') if cuda and torch.cuda.is_available() else torch.device('cpu')
# For DataParallel with pack_pad and pad_packed
docs = docs[:,:doc_lens[0]].contiguous()
bsz, in_seq = docs.size()
# Convert OOV token to UNK
inputs = docs.clone()
input_mask = inputs.ge( self.vocab_size )
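# NOTE: masked_fill_ below is in-place, but `inputs` is a LongTensor clone that
# cannot require grad, so autograd is unaffected here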
inputs.masked_fill_( input_mask, 3 ) # <UNK> token
# Document Word Embedding
# dembeds: bsz x T_e x emb_size
dembeds = self.embed( inputs )
dembeds = F.relu( self.dropout_embed( dembeds ) )
# Pack the embedding sequence
packed_dembeds = pack_padded_sequence( dembeds, doc_lens, batch_first=True )
# Bidirectional LSTM
# encoder_hiddens: used to initialize decoder
packed_ehiddens, encoder_hiddens = self.encoder( packed_dembeds )
# Unpack ehiddens
# ehiddens: bsz x T_e x (2*ehid_size)
ehiddens = pad_packed_sequence( packed_ehiddens, batch_first=True )[0]
# Decoder
_, target_length = sums.size()
output_mask = sums.ge( self.vocab_size )
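# NOTE: another in-place masked_fill_; `sums` is a LongTensor so autograd is
# unaffected, though this does mutate the caller's tensor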
sums.masked_fill_( output_mask, 3 )
# Target Summary Word Embedding
# sembeds: bsz x T_d x emb_size
sembeds = self.embed( sums )
sembeds = F.relu( self.dropout_embed( sembeds ) )
# Decoder start token
decoder_input_0 = sembeds[:,0:1,:] # SOS token
# Rewrap Encoder Hidden
decoder_hiddens_0 = [ torch.cat( torch.split( h, 1, dim=0 ), dim=-1 ) for h in encoder_hiddens ]
# Mask for Encoder Attention
# en_mask: bsz x 1 x T_e
en_mask = docs.eq(0).unsqueeze(1)
# batch and token index for Copy Attention
batch_indices = torch.arange(0, bsz).long()
batch_indices = batch_indices.expand(in_seq, bsz).transpose(0,1).contiguous().view(-1)
idx_repeat = torch.arange(0, in_seq).repeat( bsz ).long()
word_indices = docs.view(-1) # word index in vocab
numbers = docs.view(-1).tolist()
set_numbers = list(set(numbers)) # all unique numbers
if 0 in set_numbers:
    set_numbers.remove(0)
c = Counter(numbers)
dup_list = [k for k in set_numbers if (c[k]>1)]
# Cache probs of all timesteps
p_y = []
# Initialize decoder input and hidden
decoder_input = decoder_input_0
decoder_hiddens = decoder_hiddens_0
# Decoder unidirectional LSTM
for t in range( 1, target_length+1 ):
    # h_dt: bsz x 1 x dhid_size
    # decoder_hiddens: h_t, c_t
    h_dt, decoder_hiddens = self.decoder( decoder_input, decoder_hiddens )
    # Intra-Temporal Attention
    # e_t: bsz x 1 x T_e
    e_t = torch.matmul( h_dt, self.We_attn )
    e_t = torch.bmm( e_t, ehiddens.transpose(1,2) )
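    # Temporal attention (Paulus et al., 2017): ep_t = exp(e_t) at t == 1,
    # afterwards exp(e_t) normalized by the sum of all past exp(e_j), j < t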
    if t == 1:
        ep_t = torch.exp( e_t ) # bsz x 1 x T_e
        e = e_t
    else:
        ep_t = torch.exp( e_t ) / torch.sum( torch.exp( e ), dim=1, keepdim=True ) # bsz x 1 x T_e
        e = torch.cat( [e, e_t], dim=1 ) # bsz x t x T_e
    # Encoder Attention
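    # NOTE: the masked_fill_ below writes into ep_t in place. At t == 1, ep_t is
    # the output of torch.exp, which autograd saves for its backward pass, so I
    # suspect this in-place write is what backward() is complaining about.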
    ep_t.masked_fill_( en_mask, 0 )
    en_alpha_t = ep_t / torch.sum( ep_t, dim=2, keepdim=True )
    # Encoder Context
    # en_context_t: bsz x 1 x (2*ehid_size)
    en_context_t = torch.bmm( en_alpha_t, ehiddens )
    # Decoder Context Vector
    if t == 1:
        de_context_t = torch.zeros( ( bsz, 1, self.dhid_size ), device=device )
        dhidden = h_dt
    else:
        # Intra-Decoder Attention
        # ed_t: bsz x 1 x t-1
        ed_t = torch.matmul( h_dt, self.Wd_attn )
        ed_t = torch.bmm( ed_t, dhidden.transpose(1,2) )
        de_alpha_t = F.softmax( ed_t, dim=2 )
        de_context_t = torch.bmm( de_alpha_t, dhidden )
        dhidden = torch.cat( [dhidden, h_dt], dim=1 )
    # Merged_context: bsz x 1 x (3*dhid_size)
    merged_context = torch.cat( [ h_dt, en_context_t, de_context_t ], dim=2 )
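    # Pointer-generator mixture: p(y_t) = p(u=1) * p(y_t|u=1) + (1 - p(u=1)) * p(y_t|u=0)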
    # p(y_t|u=0)
    # p_yt_u0: bsz x 1 x vocab_size
    p_yt_u0 = F.softmax( self.out( self.dropout( merged_context ) ), dim=2 )
    oovs = torch.zeros( (bsz, 1, self.max_oov), device=device ) + 1.0/self.vocab_size # small epsilon to avoid zero prob
    p_yt_u0 = torch.cat( [p_yt_u0, oovs], dim=2 )
    # p(u=1)
    # p_u1: bsz x 1 x 1
    p_u1 = torch.sigmoid( self.copy( self.dropout( merged_context ) ) )
    # Encoder Attention Distribution
    # p(y_t|u=1)
    attn = en_alpha_t.squeeze(1)
    masked_idx_sum = torch.zeros( (bsz, in_seq), device=device )
    dup_attn_sum = torch.zeros( (bsz, in_seq), device=device )
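    # NOTE: the += accumulations below are in-place too, but both accumulators
    # were freshly created above, so no tensor saved for backward is overwritten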
    for dup in dup_list:
        mask = docs.eq( dup ).float()
        masked_idx_sum += mask
        attn_mask = mask * attn
        attn_sum = attn_mask.sum( 1, keepdim=True )
        dup_attn_sum += mask * attn_sum
    attn = attn * (1 - masked_idx_sum) + dup_attn_sum
    p_yt_u1 = torch.zeros( (bsz, self.vocab_size+self.max_oov), device=device )
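    # NOTE: the indexed += below is an in-place scatter into p_yt_u1; p_yt_u1 was
    # just created so autograd is fine, but duplicate indices do not accumulate,
    # which is why duplicated source tokens are pre-summed in dup_attn_sum above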
    p_yt_u1[ batch_indices, word_indices ] += attn[ batch_indices, idx_repeat ]
    p_yt_u1 = p_yt_u1.unsqueeze(1) # bsz x 1 x (vocab_size+max_oov)
    # p(y_t): bsz x 1 x (vocab_size+max_oov)
    p_yt = p_u1 * p_yt_u1 + ( 1-p_u1 ) * p_yt_u0
    # Concatenate for Training
    p_y.append( p_yt )
    # Scheduled Sampling
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    if use_teacher_forcing:
        if t < target_length: # limit to feasible tokens
            decoder_input = sembeds[:, t:t+1, :]
    else:
        topv, topi = p_yt.topk( 1, dim=2 )
        next_token = topi.squeeze(1).detach()
        next_mask = next_token.ge( self.vocab_size )
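        # NOTE: in-place masked_fill_ on a detached LongTensor; harmless for autograd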
        next_token.masked_fill_( next_mask, 3 ) # OOV -> UNK
        decoder_input = F.relu( self.dropout_embed( self.embed( next_token ) ) )
# log_p_y: bsz x T_d x (vocab_size+max_oov)
p_y = torch.cat( p_y, dim=1 )
log_p_y = torch.log( p_y )