Sparse Matrices and Broadcasting

Is there an obvious trick for working around the lack of broadcasting support for sparse matrices?

Take for example:

import torch

a = torch.rand(4).to_sparse()     # shape (4,)
b = torch.rand(4, 4).to_sparse()  # shape (4, 4)
a * b                             # dense tensors would broadcast a over the last dim of b

which throws:

RuntimeError: sparse_binary_op_intersection_cpu(): expects sparse inputs with equal dimensionality, number of sparse dimensions, and shape of sparse dimensions

I can’t use to_dense() because these are massive tensors, and I also need gradients with respect to a (requires_grad=True). Any workaround?

Thanks!
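One possible workaround, sketched here under some assumptions (this is not an official PyTorch API): with broadcasting, a * b for a of shape (N,) and b of shape (M, N) means out[i, j] = a[j] * b[i, j]. The result therefore has exactly the sparsity pattern of b. You can compute it by scaling b's stored values by a gathered at each nonzero's column index, then rebuilding a COO tensor with b's indices. The sketch assumes a can be held as a dense 1-D tensor (N elements, much smaller than the M×N b) and that b has only sparse dimensions. The helper name sparse_mul_lastdim is hypothetical. Gradients reach a through the indexing and through torch.sparse_coo_tensor.

```python
import torch


def sparse_mul_lastdim(a_dense: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: broadcast-multiply a 1-D dense tensor over the
    last dimension of a sparse COO tensor, i.e. out[..., j] = a[j] * b[..., j].

    Assumes b has only sparse dimensions (no dense value dimensions) and that
    a_dense.shape[0] == b.shape[-1]. b itself is never densified.
    """
    b = b.coalesce()                      # indices()/values() require a coalesced tensor
    idx = b.indices()                     # shape [b.sparse_dim(), nnz]
    vals = b.values() * a_dense[idx[-1]]  # scale each nonzero by a at its last-dim index
    # Reuse b's sparsity pattern; autograd tracks vals, and hence a_dense
    return torch.sparse_coo_tensor(idx, vals, b.shape)


torch.manual_seed(0)
a = torch.rand(4, requires_grad=True)
b = torch.rand(4, 4).to_sparse()

out = sparse_mul_lastdim(a, b)

# Sanity check on this tiny example only: compare with the dense broadcast.
print(torch.allclose(out.to_dense(), a.detach() * b.to_dense()))

# Gradients reach a: d(sum(out)) / d(a[j]) = sum_i b[i, j]
out.to_dense().sum().backward()
print(torch.allclose(a.grad, b.to_dense().sum(0)))
```

The to_dense() calls above exist only to verify the 4×4 example; the helper itself only touches b's nonzeros. If a is stored sparse, a.to_dense() costs just N elements. For real-sized tensors you would reduce `out` with a sparse-aware operation instead of densifying it.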