Motivated by this example in the FlexAttention blog post:
bias = torch.randn(1024, 1024)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx][kv_idx]  # The bias tensor can change!
However, I don't want to construct the bias on the fully connected graph because of memory constraints.
I’ve tried something like:
import torch
from torch import cond
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# B, N, A, node_h, and window_mask are defined elsewhere in my code.
num_neighbors = 50
edge_idx = torch.randint(low=0, high=N * A, size=(B, N * A, num_neighbors)).cuda()
bias = torch.randn(B, N * A, num_neighbors).cuda()

def score_mod(score, b, h, q_idx, kv_idx):
    true_fn = lambda: score + bias[b, q_idx, torch.searchsorted(edge_idx[b, q_idx], kv_idx)]
    false_fn = lambda: score
    boolean = torch.isin(kv_idx, edge_idx[b, q_idx])
    score = cond(boolean, true_fn, false_fn, ())
    return score

block_mask = create_block_mask(window_mask, B=None, H=None, Q_LEN=N * A, KV_LEN=N * A)
flex_attention(node_h.unsqueeze(1),
               node_h.unsqueeze(1),
               node_h.unsqueeze(1),
               block_mask=block_mask, score_mod=score_mod)
Here boolean = torch.isin(kv_idx, edge_idx[b, q_idx]) checks, for a (query, key) pair, whether the key is one of the query's neighbors, and if it is, torch.searchsorted(edge_idx[b, q_idx], kv_idx) returns the local index into bias corresponding to the global index kv_idx (we can assume edge_idx is sorted in increasing order along its last dimension).
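As a sanity check outside of flex_attention, the lookup logic itself does what I want on plain tensors (a toy sketch with made-up numbers, just to show the intent):

import torch

# One query with 4 sorted global neighbor indices and their biases.
neighbors = torch.tensor([3, 7, 12, 20])             # edge_idx[b, q_idx], sorted ascending
neighbor_bias = torch.tensor([0.1, 0.2, 0.3, 0.4])   # bias[b, q_idx]

kv_idx = torch.tensor(12)                            # a global key index
is_neighbor = torch.isin(kv_idx, neighbors)          # tensor(True)
local = torch.searchsorted(neighbors, kv_idx)        # tensor(2): position of 12 in neighbors
print(is_neighbor, neighbor_bias[local])             # tensor(True) tensor(0.3000)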
I get the error
UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
I don't understand this error, since I don't have much experience with torch.compile. I have also tried a sparse tensor format, but I ran into errors there that made it seem like a dead end.
The issue seems to boil down (I could be missing a lot) to the fact that I have to do two things: (1) check whether the key is a neighbor of the query, and (2) map the global index kv_idx to a local index (the position of that neighbor in bias). Would precomputing the lookup instead of using torch.searchsorted get around the problem? Any input or thoughts would be appreciated, thank you!
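For concreteness, by precomputing the lookup I mean something like the sketch below (untested; local_idx, slot, sentinel, and bias_padded are names I made up here, and the dense (B, N*A, N*A) table brings back exactly the quadratic memory I was hoping to avoid, so I may be thinking about this the wrong way):

import torch

L = N * A
sentinel = num_neighbors                       # extra slot meaning "kv is not a neighbor of q"

# Dense (B, L, L) table: local slot of kv_idx among q_idx's neighbors, or `sentinel`.
local_idx = torch.full((B, L, L), sentinel, dtype=torch.long, device="cuda")
slot = torch.arange(num_neighbors, device="cuda").expand(B, L, num_neighbors)
local_idx.scatter_(-1, edge_idx, slot)

# Pad bias with a zero slot so the sentinel contributes nothing.
bias_padded = torch.cat([bias, torch.zeros(B, L, 1, device="cuda")], dim=-1)

def score_mod(score, b, h, q_idx, kv_idx):
    # Plain indexing only, no cond / isin / searchsorted.
    return score + bias_padded[b, q_idx, local_idx[b, q_idx, kv_idx]]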
I am working with torch==2.6.0 and CUDA 12.2.