I'm converting another model class to TorchScript with torch.jit.trace and keep getting tracer warnings. This is the model class, a global attention module:
import torch
import torch.nn as nn

def conv1x1(in_planes, out_planes):
    # 1x1 convolution helper (defined in the same file in the original code)
    return nn.Conv2d(in_planes, out_planes, kernel_size=1,
                     stride=1, padding=0, bias=False)

class GlobalAttentionGeneral(nn.Module):
    def __init__(self, idf, cdf):
        super(GlobalAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax(dim=1)
        self.mask = None

    def applyMask(self, mask):
        self.mask = mask  # batch x sourceL

    def forward(self, input, context):
        """
        input: batch x idf x ih x iw (queryL = ih * iw)
        context: batch x cdf x sourceL
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()
        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        # (batch x queryL x idf)(batch x idf x sourceL)
        # --> batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        # --> batch*queryL x sourceL
        attn = attn.view(batch_size * queryL, sourceL)
        if self.mask is not None:
            # batch_size x sourceL --> batch_size*queryL x sourceL
            mask = self.mask.repeat(queryL, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))
        attn = self.sm(attn)  # Eq. (2)

        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)
        return weightedContext, attn
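For context, this is roughly how I produce the trace; the module sizes and tensor shapes below are made-up placeholders, not my real ones:

attn_module = GlobalAttentionGeneral(idf=32, cdf=256)  # illustrative dims
attn_module.eval()
# all-False mask, just so the masked_fill_ branch is exercised while tracing
attn_module.applyMask(torch.zeros(4, 18, dtype=torch.bool))  # batch x sourceL
example_input = torch.randn(4, 32, 17, 17)   # batch x idf x ih x iw
example_context = torch.randn(4, 256, 18)    # batch x cdf x sourceL
traced = torch.jit.trace(attn_module, (example_input, example_context))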
During tracing I keep getting this warning for the masked_fill_ line:
GlobalAttention.py:108: TracerWarning: There are 2 live references to
the data region being modified when tracing in-place operator
masked_fill_. This might cause the trace to be incorrect, because all
other views that also reference this data will not reflect this
change in the trace! On the other hand, if all other views use the
same memory chunk, but are disjoint (e.g. are outputs of torch.split),
this might still be safe.
attn.data.masked_fill_(mask.data, -float('inf'))
What's the best way to replace this line?
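My best guess is to drop the .data accesses and switch to the out-of-place masked_fill, so nothing mutates storage that other views still reference, something like this sketch:

if self.mask is not None:
    # batch_size x sourceL --> batch_size*queryL x sourceL
    mask = self.mask.repeat(queryL, 1)
    # out-of-place: returns a new tensor instead of mutating shared storage
    attn = attn.masked_fill(mask, -float('inf'))

Is that the right replacement, or is there a better way to keep the trace correct?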