Sparse matrix acceleration

I am developing a novel cross-attention mechanism for transformers. Both q and k have shape (b, n, l, d), where b is the batch size, n the number of heads, l the sequence length, and d the feature dimension.

Attention is computed by applying F.relu to both q and k and then reshaping them, roughly as follows:

F.relu(q).reshape(b, n * l, d) @ F.relu(k).reshape(b, n * l, d).transpose(-2, -1)

However, after the reshape this multiplies two large matrices and produces an output of shape (b, n*l, n*l), which incurs a substantial GPU memory cost and hurts both efficiency and resource utilization.
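For concreteness, with purely illustrative values such as b = 8, n = 8, l = 4096, d = 64, the result has shape (8, 32768, 32768), i.e. roughly 8.6e9 entries, which is about 32 GiB in float32 before any downstream computation.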

Because ReLU zeroes out a large fraction of the entries, this attention computation can be viewed as a product of two sparse matrices. To reduce the memory cost, one alternative would be to perform a batched sparse matrix x batched sparse matrix multiplication, along the lines of the sketch below.
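Something like the following rough sketch is what I have in mind. The shapes are made-up example values, and the per-batch Python loop is only there because, as far as I can tell, 2D torch.sparse.mm is the closest available primitive rather than a true batched sparse-sparse matmul:

```python
import torch
import torch.nn.functional as F

# Hypothetical example shapes, just for illustration
b, n, l, d = 2, 4, 128, 64
q = torch.randn(b, n, l, d)
k = torch.randn(b, n, l, d)

# ReLU zeroes out roughly half of the entries, so the activations are sparse-ish
q_act = F.relu(q).reshape(b, n * l, d)
k_act = F.relu(k).reshape(b, n * l, d)

# No native batched sparse x sparse matmul (as far as I know), so loop over the
# batch and use 2D torch.sparse.mm, which accepts two sparse COO matrices and
# returns a sparse result.
scores = []
for i in range(b):
    q_sp = q_act[i].to_sparse()                   # (n*l, d), sparse COO
    kt_sp = k_act[i].t().to_sparse()              # (d, n*l), sparse COO
    scores.append(torch.sparse.mm(q_sp, kt_sp))   # (n*l, n*l), sparse
```

The obvious downside is that the Python loop serializes the batch, so I am not sure this would actually be faster, even if it saves memory.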

Do you think this approach is feasible?