Self-Attention (on words) and masking

I have a simple model for text classification. It has an attention layer after an RNN, which computes a weighted average of the RNN's hidden states. I sort each batch by length and use pack_padded_sequence in order to avoid computing the padded timesteps. The model works, but I want to apply masking to the attention scores/weights.
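
Roughly, this is how the attention layer sits in the model (a simplified sketch, not my exact code; Classifier and the other module names are just placeholders):

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class Classifier(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_classes):
        super(Classifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.rnn = nn.LSTM(emb_size, hidden_size, batch_first=True)
        self.attention = SelfAttention(hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, lengths):
        # x: (batch_size, max_len) padded token ids, sorted by decreasing length
        embedded = self.embedding(x)
        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
        outputs, _ = self.rnn(packed)  # outputs is still a PackedSequence
        representations, attentions = self.attention(outputs)
        return self.fc(representations), attentions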

Here is my Layer:

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter, init
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence


class SelfAttention(nn.Module):
    def __init__(self, hidden_size, batch_first=False):
        super(SelfAttention, self).__init__()

        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.att_weights = Parameter(torch.Tensor(1, hidden_size),
                                     requires_grad=True)

        init.xavier_uniform(self.att_weights.data)

    def get_mask(self):
        pass

    def forward(self, inputs):

        if isinstance(inputs, PackedSequence):
            # unpack output
            inputs, lengths = pad_packed_sequence(inputs,
                                                  batch_first=self.batch_first)
        if self.batch_first:
            batch_size, max_len = inputs.size()[:2]
        else:
            max_len, batch_size = inputs.size()[:2]

        # att = torch.mul(inputs, self.att_weights.expand_as(inputs))
        # att = att.sum(-1)
        weights = torch.bmm(inputs,
                            self.att_weights  # (1, hidden_size)
                            .permute(1, 0)  # (hidden_size, 1)
                            .unsqueeze(0)  # (1, hidden_size, 1)
                            .repeat(batch_size, 1, 1) # (batch_size, hidden_size, 1)
                            )

        attentions = F.softmax(F.relu(weights.squeeze()))

        # apply weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(1).squeeze()

        return representations, attentions

I tried adding this, but it obviously fails with “RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation”:

        ...
        attentions = F.softmax(F.relu(weights.squeeze()))

        # apply masking based on the sentence lengths
        for i, l in enumerate(lengths[1:], 1):  # skip the first sentence
            attentions[i, l:] = 0

        # apply weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))
        ...

How can I mask the attention weights?


You can instead create a 0/1 mask variable and do

attentions = attentions * mask
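
For example, the mask could be built from the sentence lengths along these lines (a rough sketch, assuming batch-first attention scores and a Python list of lengths; you will probably also want to renormalize the rows afterwards):

# 1 for real timesteps, 0 for the zero-padded ones: (batch_size, max_len)
mask = Variable(torch.zeros(attentions.size()))
for i, l in enumerate(lengths):
    mask[i, :l] = 1

attentions = attentions * mask  # out-of-place, so autograd is fine with it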

@ruotianluo Thanks. I will try it and post back. Do I have to detach() the variable or something like that? I am not sure how this operation will affect the flow of the gradients.

So this is what i came up with:

class SelfAttention(nn.Module):
    def __init__(self, hidden_size, batch_first=False):
        super(SelfAttention, self).__init__()

        self.hidden_size = hidden_size
        self.batch_first = batch_first

        self.att_weights = Parameter(torch.Tensor(1, hidden_size),
                                     requires_grad=True)

        init.xavier_uniform(self.att_weights.data)

    def get_mask(self):
        pass

    def forward(self, inputs):

        if isinstance(inputs, PackedSequence):
            # unpack output
            inputs, lengths = pad_packed_sequence(inputs,
                                                  batch_first=self.batch_first)
        if self.batch_first:
            batch_size, max_len = inputs.size()[:2]
        else:
            max_len, batch_size = inputs.size()[:2]

        # apply attention layer
        weights = torch.bmm(inputs,
                            self.att_weights  # (1, hidden_size)
                            .permute(1, 0)  # (hidden_size, 1)
                            .unsqueeze(0)  # (1, hidden_size, 1)
                            .repeat(batch_size, 1, 1)  # (batch_size, hidden_size, 1)
                            )

        attentions = F.softmax(F.relu(weights.squeeze()))

        # create mask based on the sentence lengths
        mask = Variable(torch.ones(attentions.size())).cuda()
        for i, l in enumerate(lengths):
            if l < max_len:
                mask[i, l:] = 0

        # apply mask and renormalize attention scores (weights)
        masked = attentions * mask
        _sums = masked.sum(-1).expand_as(attentions)  # sums per row
        attentions = masked.div(_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(1).squeeze()

        return representations, attentions

It works and the network performs better.
Hope it helps someone.

However, I am concerned about the gradients of the attention layer. How does the masking operation affect them? Is this the right way, or is there a more elegant way of achieving the same thing?


Hello Christos,

With

attentions = Variable(torch.randn(5,10).cuda())
max_len = attentions.size(1)
lengths = ((torch.arange(0,5)+5).long().cuda())

as an example, I think you can do this faster if you do:

idxes = torch.arange(0,max_len,out=torch.LongTensor(max_len)).unsqueeze(0).cuda() # some day, you'll be able to directly do this on cuda
mask = Variable((idxes<lengths.unsqueeze(1)).float())

(works on master / 0.2, you need to expand_as(attention) or so on 0.1.12)
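
In the forward above, this replaces the Python loop that builds the mask, e.g. (just a sketch; pad_packed_sequence gives the lengths back as a plain list, so it needs to be wrapped in a LongTensor first):

# replaces the "for i, l in enumerate(lengths)" loop
lengths_t = torch.LongTensor(lengths).cuda()  # lengths comes back as a Python list
idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0).cuda()
mask = Variable((idxes < lengths_t.unsqueeze(1)).float())  # (batch_size, max_len)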

If you multiply the output by 0 with the mask, this will propagate a gradient of 0 to the masked attention positions. I think this is the right way in general, but I don’t have the expertise to say how best to handle the end of the sequence…
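
A quick toy check of that, just to illustrate (nothing rigorous):

a = Variable(torch.randn(2, 4), requires_grad=True)
mask = Variable(torch.Tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]]))
(a * mask).sum().backward()
print(a.grad)  # zeros exactly at the masked positions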

Best regards

Thomas


Hi @cbaziotis, thanks for sharing!
For the final representation, can you point me to the related paper or material? Thanks.

@lyan62 I’ve used this attention mechanism in the following paper: https://arxiv.org/abs/1804.06659. I will share the source code of our models in this repo https://github.com/cbaziotis/ntua-slp-semeval2018-task3 in a few days. In the meantime, you can find my implementation of the simple self-attention mechanism here: https://gist.github.com/cbaziotis/94e53bdd6e4852756e0395560ff38aa4


@cbaziotis, Can you explain a bit about the lengths parameter in forward()? I was looking through your gist https://gist.github.com/cbaziotis/94e53bdd6e4852756e0395560ff38aa4#file-selfattention-py-L31. It looks like that parameter is a Variable, but I am not sure where it comes from.

@yngtodd lengths is a list of integers, in which each number is the actual (unpadded) length of the corresponding sentence.

It is needed in order to take the zero-padded timesteps into account. Only by knowing the length of each sentence can we properly calculate the attention weights: we want to assign weight only to the actual words. Without the lengths, the zero-padded timesteps would be assigned non-zero weights by the attention mechanism.
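
For example, for a toy batch of three sentences (assuming 0 is the padding index and the batch is sorted by decreasing length):

batch = torch.LongTensor([[4, 7, 12, 2, 9],
                          [3, 25, 8, 0, 0],
                          [6, 0, 0, 0, 0]])
lengths = [5, 3, 1]  # number of real (non-padded) tokens in each sentence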

Hi @cbaziotis, I have a little doubt about how the SelfAttention output should be used…
The representations.shape should be (batch_size, hidden_size) and attentions.shape should be (batch_size, max_len).
How should they be used inside an LSTM? What exactly do these two variables represent?