Compute scores for only a subset of inputs

Hi!

I am working on an SRL (semantic role labeling) implementation in which the goal is to compute a tag for every possible predicate x span combination. This means we can end up with an intractably large score tensor of shape

T1 = batch_size x |tokens| x |tokens|^2 x num_tags

where the |tokens| dimension indexes candidate predicates and the |tokens|^2 dimension indexes candidate spans (every start/end pair).
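To put a number on "intractable", here is the arithmetic for one set of sizes. The sizes are hypothetical, picked only for illustration.

```python
# Hypothetical sizes, chosen only to illustrate the scale of T1.
batch_size = 8
num_tokens = 50
num_tags = 30

# One predicate axis (num_tokens) times one span axis (num_tokens**2).
num_scores = batch_size * num_tokens * num_tokens**2 * num_tags
print(num_scores)  # 30000000

# Memory for the scores alone if each one is a float32 (4 bytes).
print(num_scores * 4 / 1e6, "MB")  # 120.0 MB
```

The span axis grows quadratically and the predicate axis linearly, so the tensor grows with the cube of the sentence length.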

In practice, papers report pruning predicates and spans to sets of top-k candidates and scoring only those. I can get the top-k candidates without any problem, but I then need to compute a loss against a gold tensor whose indices do not line up with the top-k predicates and top-k spans.
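A minimal sketch of the pruning step, assuming each predicate candidate and each span candidate already has a single unary score (the random `pred_scores` and `span_scores` below are hypothetical stand-ins for real scoring layers). It keeps the top-k of each with `torch.topk` and enumerates every surviving (predicate, span) pair. Spans are assumed to be flattened as `start * num_tokens + end`, and invalid spans with `start > end` are not filtered out here.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes and unary scores; a real model would produce these
# from predicate and span scoring layers.
batch_size, num_tokens = 2, 10
k_pred, k_span = 3, 5

pred_scores = torch.randn(batch_size, num_tokens)               # one score per predicate candidate
span_scores = torch.randn(batch_size, num_tokens * num_tokens)  # one score per (start, end), flattened

# Keep the k best predicates and the k best spans in each batch element.
top_pred = pred_scores.topk(k_pred, dim=1).indices  # (B, k_pred)
top_span = span_scores.topk(k_span, dim=1).indices  # (B, k_span)

# All (predicate, span) pairs among the survivors:
# (B, k_pred, k_span) flattened to (B, k_pred * k_span).
pair_pred = top_pred.unsqueeze(2).expand(-1, -1, k_span).reshape(batch_size, -1)
pair_span = top_span.unsqueeze(1).expand(-1, k_pred, -1).reshape(batch_size, -1)

print(pair_pred.shape, pair_span.shape)  # torch.Size([2, 15]) torch.Size([2, 15])
```

Only `k_pred * k_span` pairs per sentence need tag scores, instead of `num_tokens * num_tokens**2`.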

I think the solution is to expand the batch_size x (|top-k predicates| * |top-k spans|) x num_tags tensor back to the full shape of T1 above, filling every position that did not receive a score from the top-k combinations with 0s.
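Expanding back into T1 requires knowing where each kept (predicate, start, end) combination lands on T1's flattened |tokens| * |tokens|^2 axis. Assuming that axis is flattened in row-major order (predicate outermost, then span start, then span end), the position is a simple formula. The helper `flat_index` is hypothetical, and the check compares it against PyTorch's own flattening.

```python
import torch

num_tokens = 10  # hypothetical sentence length

def flat_index(pred, start, end, num_tokens):
    """Position of (predicate, span start, span end) on T1's flattened
    |tokens| * |tokens|^2 axis, assuming row-major (C-order) flattening.
    Works on Python ints or on integer tensors of matching shape."""
    return pred * num_tokens * num_tokens + start * num_tokens + end

# Check the formula against how PyTorch flattens a (T, T, T) grid.
grid = torch.zeros(num_tokens, num_tokens, num_tokens)
grid[4, 2, 7] = 1.0
assert grid.flatten().argmax().item() == flat_index(4, 2, 7, num_tokens)
print(flat_index(4, 2, 7, num_tokens))  # 427
```

With the top-k pairs from the pruning step, `flat_index(pair_pred, pair_span // num_tokens, pair_span % num_tokens, num_tokens)` would give a `(batch_size, K)` index tensor of positions on that axis.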

My plan is to initialize T1 to all zeros:

T1 = torch.zeros(batch_size, num_tokens * num_tokens**2, num_tags)

Then I would use torch.scatter to insert the scores I computed:

T1.scatter(dim=?, index=i, src=topk_combos_tensor)
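A minimal sketch of the scatter, assuming the flattened (predicate, span) positions of the kept combinations are already available as a `(batch_size, K)` tensor `flat_idx` (here random and unique per row; in a real model they would come from the top-k selection). The scatter runs along `dim=1`, the flattened predicate x span axis. Because `scatter_` requires `index` to have the same number of dimensions as `src`, the index is broadcast over the tag dimension.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes.
batch_size, num_tokens, num_tags = 2, 6, 4
num_pairs = 5                           # kept (predicate, span) combinations per batch element
full_size = num_tokens * num_tokens**2  # flattened predicate x span axis of T1

# Flat positions of the kept combinations on T1's second dimension, shape (B, K).
# They should be unique within a row: with duplicates, which value ends up in T1
# is nondeterministic.
flat_idx = torch.stack([torch.randperm(full_size)[:num_pairs] for _ in range(batch_size)])

# Scores computed only for the kept combinations, shape (B, K, num_tags).
topk_scores = torch.randn(batch_size, num_pairs, num_tags)

T1 = torch.zeros(batch_size, full_size, num_tags)

# Along dim=1, scatter_ does: T1[b, index[b, k, t], t] = src[b, k, t].
index = flat_idx.unsqueeze(-1).expand(-1, -1, num_tags)
T1.scatter_(1, index, topk_scores)

# Each kept combination now holds its scores; every other position stays 0.
assert torch.equal(T1[0, flat_idx[0]], topk_scores[0])
print(T1.shape)  # torch.Size([2, 216, 4])
```

If `topk_scores` requires gradients, they flow back through `scatter_` into it, so the expanded T1 can be fed directly to a loss against the full gold tensor.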

But I am not sure how to build the index for this. Is there an easier way to achieve this without computing scores over the entire space? Otherwise, can someone explain how to use torch.scatter correctly here? I may be misunderstanding how the scatter method works.
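One alternative that avoids building T1 at all is to go the other way: gather the gold tags at the kept positions and compute the loss on the pruned set only. This is a swapped-in technique, not what the question describes, and it changes the loss: gold arguments that were pruned away are never penalized. The sketch assumes a hypothetical gold tensor holding one tag id per flattened (predicate, span) position, with 0 as a "no role" tag.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes.
batch_size, num_tokens, num_tags = 2, 6, 4
num_pairs = 5
full_size = num_tokens * num_tokens**2

# Hypothetical gold tensor: one tag id per (predicate, span) position, 0 = "no role".
gold = torch.randint(0, num_tags, (batch_size, full_size))

# Flat positions of the kept combinations (B, K) and their scores (B, K, num_tags).
flat_idx = torch.stack([torch.randperm(full_size)[:num_pairs] for _ in range(batch_size)])
topk_scores = torch.randn(batch_size, num_pairs, num_tags, requires_grad=True)

# Look up the gold tag of each kept combination instead of expanding the scores.
gold_topk = gold.gather(1, flat_idx)  # (B, K)

loss = F.cross_entropy(topk_scores.reshape(-1, num_tags), gold_topk.reshape(-1))
loss.backward()
print(loss.item())
```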

Update: I think masked_scatter will solve my problem.
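One caveat with `masked_scatter`: it copies values from the source into the True positions of the mask in row-major order, not in the order the top-k indices were produced. A minimal sketch under the same hypothetical setup (random unique `flat_idx` per row) therefore sorts the scores by their flat position before scattering.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes.
batch_size, num_tokens, num_tags = 2, 6, 4
num_pairs = 5
full_size = num_tokens * num_tokens**2

# Flat positions of the kept combinations (unique per row) and their scores.
flat_idx = torch.stack([torch.randperm(full_size)[:num_pairs] for _ in range(batch_size)])
topk_scores = torch.randn(batch_size, num_pairs, num_tags)

# masked_scatter fills True positions in row-major order, so reorder the
# scores by their flat position within each batch element.
order = flat_idx.argsort(dim=1)
sorted_scores = topk_scores.gather(1, order.unsqueeze(-1).expand(-1, -1, num_tags))

# Boolean mask that is True at every kept (predicate, span) position, for all tags.
mask = torch.zeros(batch_size, full_size, dtype=torch.bool)
mask[torch.arange(batch_size).unsqueeze(1), flat_idx] = True
mask = mask.unsqueeze(-1).expand(-1, -1, num_tags)

T1 = torch.zeros(batch_size, full_size, num_tags).masked_scatter(mask, sorted_scores)

# Same result as the scatter approach: each position holds its own scores.
assert torch.equal(T1[0, flat_idx[0]], topk_scores[0])
```

Without the sort, every score would still land in some kept position, but generally not the right one, and nothing would raise an error.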