I am working on a semantic role labeling (SRL) implementation, where the goal is to compute a tag for every possible predicate x span combination. This means we can end up with an intractable scoring space of
T1 = batch_size x |tokens| x |tokens^2| x |num_tags|
where the |tokens| dimension indexes predicates and the |tokens^2| dimension indexes spans (every start/end pair).
In practice, papers report pruning the predicates and spans to sets of top-k candidates and computing scores only for these. I can get the top-k candidates with no problem, but I then need to compute a loss against a gold tensor whose indices do not line up with those of the top-k predicates and top-k spans.
I think the solution is to simply expand the batch_size x (|top-k predicates| * |top-k spans|) x |tags| tensor back to the shape of T1 above, with every position that did not receive a score among the top-k combinations filled with 0’s.
My plan is to first initialize the tensor T1 (with the predicate and span dimensions flattened into one) to all 0’s:
T1 = torch.zeros(batch_size, |tokens| * |tokens^2|, |num_tags|)
, and then use torch.scatter to insert the scores I computed:
T1.scatter(dim=?, index=i, src=top-k-combos_tensor)
But I am not sure how to easily index in such a way. Is there an easier way to achieve what I described here without running the entire space of computations? Or otherwise, can someone describe how I can use torch.scatter correctly here? I may be having some trouble understanding the
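
For what it's worth, here is a minimal sketch of how I think the scatter indexing could work. It rests on assumptions that are not stated above: the pruned scores are shaped (batch, k_pred, k_span, num_tags), the kept predicates are given as token indices, and the kept spans are given as flat indices into the |tokens^2| span axis. The function name `expand_topk_scores` and all argument names are made up for illustration.

```python
import torch

def expand_topk_scores(scores, pred_idx, span_idx, num_tokens):
    """Place pruned scores back into the full predicate x span space.

    scores:   (batch, k_pred, k_span, num_tags) scores for the kept pairs
    pred_idx: (batch, k_pred) long, token index of each kept predicate
    span_idx: (batch, k_span) long, flat span index in [0, num_tokens ** 2)
    """
    batch, k_pred, k_span, num_tags = scores.shape
    num_spans = num_tokens * num_tokens

    # Flat position of each (predicate, span) pair on the |tokens| * |tokens^2| axis.
    flat_idx = pred_idx[:, :, None] * num_spans + span_idx[:, None, :]
    flat_idx = flat_idx.reshape(batch, k_pred * k_span)

    # scatter wants an index with as many dimensions as src,
    # so the same position is repeated along the tag axis.
    index = flat_idx[:, :, None].expand(-1, -1, num_tags)
    src = scores.reshape(batch, k_pred * k_span, num_tags)

    full = torch.zeros(batch, num_tokens * num_spans, num_tags,
                       dtype=scores.dtype, device=scores.device)
    full.scatter_(1, index, src)  # in-place; scatter() would return a new tensor
    return full.view(batch, num_tokens, num_spans, num_tags)

# Toy usage with random indices standing in for the real top-k selections.
torch.manual_seed(0)
B, N, T, Kp, Ks = 2, 4, 3, 2, 3
pred_idx = torch.stack([torch.randperm(N)[:Kp] for _ in range(B)])
span_idx = torch.stack([torch.randperm(N * N)[:Ks] for _ in range(B)])
scores = torch.randn(B, Kp, Ks, T)
full = expand_topk_scores(scores, pred_idx, span_idx, N)
```

Two caveats on this sketch. First, it still allocates the full-size tensor, so it only avoids the scoring cost, not the memory cost. Indexing the gold tensor with the same indices, e.g. `gold[torch.arange(B)[:, None, None], pred_idx[:, :, None], span_idx[:, None, :]]`, might be a cheaper alternative, though gold entries dropped by pruning would then not contribute to the loss. Second, a 0 fill is an ordinary score as far as most losses are concerned, so whether 0 is the right value for pruned positions probably depends on the loss being used.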