Is there an easy way to compute a “head-mask-specified” pooling operation?

For example, I have a head mask h = [1, 1, 0, 0, 1], where 1 denotes the head of a span.
I also have a tensor a of size (5, 100). I want to implement a function that computes a new tensor b of size (3, 100), such that b[0] = a[0], b[1] = a[1] + a[2] + a[3] (as denoted by h), and b[2] = a[4].
Doing this with a for loop is easy, but is there a more efficient way? What if h and a are batched?
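For concreteness, here is the for-loop version I am trying to replace (a minimal sketch; span_pool_loop is just an illustrative name):

import torch

def span_pool_loop(a, h):
    # Start a new output row at each head position (h == 1) and
    # add each non-head row into the most recent head's row.
    out = []
    for i, flag in enumerate(h):
        if flag == 1:
            out.append(a[i].clone())
        else:
            out[-1] += a[i]
    return torch.stack(out)

a = torch.randn(5, 100)
h = [1, 1, 0, 0, 1]
b = span_pool_loop(a, h)  # (3, 100), with b[1] == a[1] + a[2] + a[3]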

Hi, ryancc
I'm still confused about what you mean. Could you explain more clearly? Maybe I can help.

Hi, thank you very much.

Say I have a sentence consisting of two words: S = [“Definitely”, “not”], and I want to turn S into an embedding matrix T of size (2, 100), where each row represents a word.

I want to use BERT embeddings. But BERT represents each word as sub-word units, so S becomes [“Def”, “##in”, “##ite”, “##ly”, “not”] (“Definitely” is tokenized as “Def”, “##in”, “##ite”, “##ly”), and BERT outputs an embedding matrix H of size (5, 100) :(.

My goal is to merge some rows of H according to the sub-word units.
For example, for “Definitely”, I should merge the embeddings of [“Def”, “##in”, “##ite”, “##ly”] to get its representation.

In my current method, I use a head mask vector h = [1, 0, 0, 0, 1] to record the “head” of each word, where 1 indicates the head position:
h = [
1, -> “Def”
0, -> “##in”
0, -> “##ite”
0, -> “##ly”
1 -> “not”
]
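For reference, the mask can be derived from the WordPiece tokens themselves, since continuation pieces start with "##" (a minimal sketch):

tokens = ["Def", "##in", "##ite", "##ly", "not"]
h = [0 if t.startswith("##") else 1 for t in tokens]
# h == [1, 0, 0, 0, 1]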
So I should merge the rows whose head mask is 0 into the preceding row whose head mask is 1. Currently I use a for loop to enumerate each element of h, which is slow and cannot be batched.

Could you suggest an efficient way to do this computation?

I have found an efficient way to deal with it.

import torch

N = 6
C = 4

data = torch.randn(N, C)
head_mask = torch.tensor([1, 1, 0, 0, 1, 0])
head_index = (head_mask == 1).nonzero().view(-1)  # head positions: [0, 1, 4]
head_index_expand = torch.cat([head_index, torch.tensor([head_mask.numel()])])  # [0, 1, 4, 6]
head_index_diff = (head_index_expand - torch.roll(head_index_expand, 1))[1:]  # span lengths: [1, 3, 2]
splits = torch.split(data, head_index_diff.tolist())  # chunks of shape (1, C), (3, C), (2, C)

print('ori data')
print(data)
print('after splitting')
print(splits)

# Each item in `splits` holds the sub-word rows of one word; summing
# the rows of each item gives the word representations you want.
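To turn those splits into the pooled matrix, you can sum each chunk and stack the results. Below is a sketch of that step, plus a fully vectorized alternative that builds segment ids with cumsum and scatters rows into word slots; the batched shapes (B, N, C) at the end are my assumption about what batched inputs would look like.

# Sum the sub-word rows of each split and stack: result is (3, C).
pooled = torch.stack([chunk.sum(dim=0) for chunk in splits])

# Vectorized alternative: a running count of heads maps every row to
# its word index, and index_add_ accumulates rows into word slots.
segment_ids = torch.cumsum(head_mask, dim=0) - 1  # [0, 1, 1, 1, 2, 2]
num_words = int(head_mask.sum())
pooled2 = torch.zeros(num_words, C).index_add_(0, segment_ids, data)
assert torch.allclose(pooled, pooled2)

# Batched sketch (assumed shapes): a is (B, N, C), h is (B, N), and
# every sequence starts with a head; unused word slots stay zero.
B = 2
a = torch.randn(B, N, C)
h = torch.tensor([[1, 1, 0, 0, 1, 0],
                  [1, 0, 0, 1, 0, 0]])
seg = torch.cumsum(h, dim=1) - 1  # per-sequence word ids, shape (B, N)
max_words = int(h.sum(dim=1).max())
out = torch.zeros(B, max_words, C)
out.scatter_add_(1, seg.unsqueeze(-1).expand(-1, -1, C), a)  # (B, max_words, C)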

Hi, Naruto-Sasuke:
I have read your solution. It is exactly what I was looking for!
Thank you very much!