FlexAttention with sparse edge bias

I'm motivated by this example in the FlexAttention blog post:

bias = torch.randn(1024, 1024)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx][kv_idx] # The bias tensor can change!

But I don't want to construct the bias on the fully connected graph, for memory reasons.

I’ve tried something like:

import torch
from torch import cond
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# B, N, A, window_mask (my mask_mod) and node_h (shape (B, N*A, head_dim)) are defined elsewhere.
num_neighbors = 50
# edge_idx[b, q] holds the global indices of q's neighbors;
# bias[b, q, j] is the edge bias for q's j-th neighbor.
edge_idx = torch.randint(low=0, high=N*A, size=(B, N*A, num_neighbors)).cuda()
edge_idx = edge_idx.sort(dim=-1).values  # keep neighbor lists sorted so searchsorted works
bias = torch.randn(B, N*A, num_neighbors).cuda()

def score_mod(score, b, h, q_idx, kv_idx):
    true_fn = lambda: score + bias[b, q_idx, torch.searchsorted(edge_idx[b, q_idx], kv_idx)]
    false_fn = lambda: score
    boolean = torch.isin(kv_idx, edge_idx[b, q_idx])
    score = cond(boolean, true_fn, false_fn, ())
    return score

block_mask = create_block_mask(window_mask, B=None, H=None, Q_LEN=N*A, KV_LEN=N*A)
flex_attention(node_h.unsqueeze(1),  # unsqueeze adds a singleton head dimension
               node_h.unsqueeze(1),
               node_h.unsqueeze(1),
               block_mask=block_mask, score_mod=score_mod)

Here boolean = torch.isin(kv_idx, edge_idx[b, q_idx]) checks, for a (query, key) pair, whether the key is one of the query's neighbors. If it is, torch.searchsorted(edge_idx[b, q_idx], kv_idx) returns the local index into bias corresponding to the global index kv_idx (the neighbor lists in edge_idx are sorted in increasing order, hence the sort above).
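
As a toy illustration of those two calls (made-up numbers):

neighbors = torch.tensor([3, 7, 12, 20])   # edge_idx[b, q_idx], sorted
kv = torch.tensor(12)                      # global key index
torch.isin(kv, neighbors)                  # tensor(True): kv is a neighbor of q
torch.searchsorted(neighbors, kv)          # tensor(2): kv's local slot in bias[b, q_idx]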

I get the error

UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

I don't understand this, not having much experience with torch.compile. Another thing I tried was the sparse tensor format, but I ran into errors there that made it seem like a dead end.

The issue seems to boil down (though I could be missing a lot) to the fact that I have to do two things per (query, key) pair: (1) check whether the key is one of the query's neighbors, and (2) look up the mapping from the global index (kv_idx) to the local index (the position of that neighbor in bias). Would precomputing the lookup instead of using torch.searchsorted get around the problem? The sketch after this paragraph is roughly what I have in mind. Any input or thoughts would be appreciated, thank you!
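
By "precomputing the lookup" I mean something like the following (untested sketch; note the lookup table is still dense over keys, just integer-valued):

S = N * A
# lookup[b, q, kv] = position of kv in edge_idx[b, q], or -1 if kv is not a neighbor of q
lookup = torch.full((B, S, S), -1, dtype=torch.long, device="cuda")
slots = torch.arange(num_neighbors, device="cuda").repeat(B, S, 1)
lookup.scatter_(2, edge_idx, slots)

def score_mod(score, b, h, q_idx, kv_idx):
    slot = lookup[b, q_idx, kv_idx]
    # torch.where instead of cond: both branches are evaluated element-wise,
    # so clamp keeps the gather in bounds for non-neighbors
    return torch.where(slot >= 0, score + bias[b, q_idx, slot.clamp(min=0)], score)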

I am working with torch==2.6.0 and CUDA 12.2.