TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator masked_fill_

I’m converting another model class with torch.jit.trace and keep getting tracer warnings for the following global attention model class:

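import torch
import torch.nn as nn

def conv1x1(in_planes, out_planes):
    # Assumed helper (its definition is not shown in the post): a plain 1x1 convolution
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)

class GlobalAttentionGeneral(nn.Module):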
    def __init__(self, idf, cdf):
        super(GlobalAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax(dim=1)
        self.mask = None

    def applyMask(self, mask):
        self.mask = mask  # batch x sourceL

    def forward(self, input, context):
        """
            input: batch x idf x ih x iw (queryL=ihxiw)
            context: batch x cdf x sourceL
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()
        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        # (batch x queryL x idf)(batch x idf x sourceL)
        # -->batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        # --> batch*queryL x sourceL
        attn = attn.view(batch_size*queryL, sourceL)
        if self.mask is not None:
            # batch_size x sourceL --> batch_size*queryL x sourceL
            mask = self.mask.repeat(queryL, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))
        attn = self.sm(attn)  # Eq. (2)
        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext, attn

I keep getting this warning for the following line:

GlobalAttention.py:108: TracerWarning: There are 2 live references to 
the data region being modified when tracing in-place operator 
masked_fill_. This might cause the trace to be incorrect, because all 
other views that also reference this data will not not reflect this 
change in the trace! On the other hand, if all other views use the 
same memory chunk, but are disjoint (e.g. are outputs of torch.split),
this might still be safe.

  attn.data.masked_fill_(mask.data, -float('inf'))

What’s the best way to replace this line?

If you’re using .data, you’re almost certainly doing it wrong with tracing (and it’s not too great an idea in general).
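As a minimal sketch of the problem and the simplest fix (assuming a recent PyTorch with boolean masks; the function name masked_scores is hypothetical): `attn.data` is a second tensor object over the same storage as `attn`, which is most likely where the “2 live references” in the warning comes from, and writing through it also bypasses autograd. The out-of-place `Tensor.masked_fill` returns a new tensor instead, so nothing is aliased and the tracer records the operation.

```python
import warnings

import torch


def masked_scores(attn, mask):
    # Hypothetical helper: out-of-place fill returns a new tensor rather than
    # writing through attn.data, so autograd and the tracer both see the op.
    return attn.masked_fill(mask, float("-inf"))


attn = torch.randn(4, 5)
mask = torch.zeros(4, 5, dtype=torch.bool)
mask[:, -1] = True  # pretend the last source position is padding

# attn.data is a separate tensor object that aliases attn's storage
print(attn.data.data_ptr() == attn.data_ptr())  # True

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(masked_scores, (attn, mask))

# Collect any tracer warnings about aliased in-place writes
live_ref_warnings = [w for w in caught if "live references" in str(w.message)]
print(len(live_ref_warnings))  # 0

out = traced(attn, mask)
```

In the model itself this corresponds to writing `attn = attn.masked_fill(mask, -float('inf'))` in place of the `attn.data.masked_fill_(...)` line.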
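If I’m understanding this correctly, you want to leave attn unchanged where it is not masked and set it to -math.inf where it is. The easiest way might be to take a 1/0-valued (float) mask, with 1 for positions to keep and 0 for positions to mask, and do attn = attn + mask.log(), since log(1) = 0 and log(0) = -inf. The mask in your code is True where attention should be blocked, so you would need to invert it first.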
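A small sketch checking that the log trick gives the same softmax as an out-of-place `masked_fill` (the shapes and mask values below are made up for illustration). It also expands the mask by broadcasting rather than with `repeat`: `attn.view(batch_size*queryL, sourceL)` is batch-major, so row i belongs to batch i // queryL, whereas `self.mask.repeat(queryL, 1)` tiles the whole batch, so row i gets mask row i % batch_size. The two orderings only agree when batch_size is 1.

```python
import torch

torch.manual_seed(0)
batch_size, queryL, sourceL = 2, 3, 4

attn = torch.randn(batch_size * queryL, sourceL)
# Boolean mask as in the question: True marks positions to hide (batch x sourceL)
mask = torch.tensor([[False, False, True, True],
                     [False, True, True, True]])

# One mask row per (batch, query) pair, in the same batch-major order
# as attn.view(batch_size * queryL, sourceL)
row_mask = (mask.unsqueeze(1)
            .expand(batch_size, queryL, sourceL)
            .reshape(batch_size * queryL, sourceL))

# Reference result: out-of-place masked_fill, then softmax
ref = torch.softmax(attn.masked_fill(row_mask, float("-inf")), dim=1)

# Log trick: keep is 1.0 where visible and 0.0 where masked,
# so keep.log() is 0 or -inf and adding it leaves visible scores unchanged
keep = (~row_mask).float()
via_log = torch.softmax(attn + keep.log(), dim=1)

print(torch.allclose(ref, via_log))  # True
# repeat(queryL, 1) tiles the batch, which is a different row order
print(torch.equal(mask.repeat(queryL, 1), row_mask))  # False
```

Both variants avoid in-place writes, so either should trace cleanly. Note that a row whose positions are all masked would produce NaNs after the softmax with either approach.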

Best regards

Thomas
