How to efficiently implement scatter_max() in PyTorch?

I implemented a scatter_max() function in PyTorch, similar to https://www.tensorflow.org/api_docs/python/tf/scatter_max , but my implementation was too slow.

Any suggestions for speeding it up?

Hi,

How did you implement it?
Using gather / max / scatter should be decent, no?

Thanks for the reply @albanD. I naively used a for loop to find the max value and assign it to the table,
and repeated this process for every data point in one batch.
I am not familiar with the gather() and scatter() functions, so could you give more hints?

def scatter_max(attn_scores, indices):
    # naive Python loop: keep a running max per index in the table
    tables = torch.zeros_like(attn_scores)
    for score, idx in zip(attn_scores, indices):
        max_score = tables[idx]
        if score > max_score:
            tables[idx] = score
    return tables
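
For reference, a small made-up example of calling it (the real attn_scores and indices come from the attention model):

import torch

attn_scores = torch.tensor([0.5, 1.2, 0.3, 2.0, 0.7])   # hypothetical 1-d scores
indices = torch.tensor([0, 2, 2, 1, 0])                  # hypothetical 1-d indices (0 and 2 repeat)
print(scatter_max(attn_scores, indices))
# tensor([0.7000, 2.0000, 1.2000, 0.0000, 0.0000])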

Could you give more details on the size of the inputs? In particular, how many dimensions do they have?

Sure!
The function I wrote is for one example in a batch, so attn_scores and indices are a 1-d FloatTensor and a 1-d LongTensor.
But I want to implement a function that works on a batch (like scatter_add() in PyTorch),
so attn_scores would be a 2-d FloatTensor of shape [batch_size, number of steps] and indices would also be a 2-d LongTensor of shape [batch_size, number of steps].
Thanks

Considering the full example:
I am not sure I understand what the indices are supposed to index if the indices tensor has the same size as the input. Also, is the original value you want to apply the "max" with always a tensor full of zeros?

For example, if I use scatter_add():

vocab_dist  = output_layer(hiddens)  # output_layer is a linear layer; output size: [batch_size, 50000]
scores = torch.randn(16, 400)   # [batch size, number of time steps]
indices = torch.randint(0, 50000, (16, 400))   # [batch_size, number of steps]
final_dist = vocab_dist.scatter_add(1, indices, scores)

Instead of adding up, I want to use max.
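
To make the semantics concrete, here is a small made-up example of what scatter_add() does when an index repeats; this is the behavior I want, but with max instead of sum:

import torch

vocab_dist = torch.zeros(1, 6)
scores = torch.tensor([[1.0, 3.0, 2.0]])
indices = torch.tensor([[2, 2, 4]])               # index 2 appears twice
print(vocab_dist.scatter_add(1, indices, scores))
# tensor([[0., 0., 4., 0., 2., 0.]])  -> duplicates are summed (1 + 3 = 4)
# with a scatter_max I would want 3. at position 2 instead of 4.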

Thanks !

My idea was this (note that scatter and scatter_add only exist as inplace ops, but you can clone vocab_dist before the scatter_ if you don't want to change it inplace):

vocab_dist  = output_layer(hiddens)  # output_layer is a linear layer; output size: [batch_size, 50000]
scores = torch.randn(16, 400)   # [batch_size, number of time steps]
indices = torch.randint(0, 50000, (16, 400))   # [batch_size, number of steps]
current_values = vocab_dist.gather(1, indices)  # read the current values at the indexed positions
new_val = torch.max(current_values, scores)     # elementwise max with the new scores
vocab_dist.scatter_(1, indices, new_val)        # write the maxima back in place
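
A runnable toy version of the same idea, with made-up small sizes and no output_layer, just to make it concrete:

import torch

vocab_dist = torch.zeros(1, 6)                  # stands in for output_layer(hiddens)
scores = torch.tensor([[1.0, 3.0, 2.0]])
indices = torch.tensor([[2, 0, 4]])
current_values = vocab_dist.gather(1, indices)  # tensor([[0., 0., 0.]])
new_val = torch.max(current_values, scores)     # tensor([[1., 3., 2.]])
vocab_dist.scatter_(1, indices, new_val)
print(vocab_dist)
# tensor([[3., 0., 1., 0., 2., 0.]])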

Thanks, but I don't think that's what I wanted.
I expect scatter_max() to behave like this:

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

out = scatter_max(src, index, out=out)

print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])


Ho,

That’s not what scatter does even without the max :smiley:
You want something like put_ but replacing the accumulate argument with a max?

Oh yes it is :slight_smile: That's what I want.
How can I do that? I still have no idea :frowning:

That will be much trickier, I'm afraid… I can't think of an efficient way to do it :confused:

Okay, thanks :smile:
I found the GitHub repository (https://github.com/rusty1s/pytorch_scatter) which implements scatter_add(), scatter_max(), etc.
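
For completeness, a sketch of how that library's scatter_max can be used on the example above, based on its README (assuming torch_scatter is installed; note that it returns both the maxima and their argmax positions):

from torch_scatter import scatter_max
import torch

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out, argmax = scatter_max(src, index, out=out)
print(out)
# tensor([[0, 0, 4, 3, 2, 0],
#         [2, 4, 3, 0, 0, 0]])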


Ho nice! Then that will be the most efficient way :slight_smile: