RuntimeError: reshape is not implemented for sparse tensors

Is there an efficient way to reshape a sparse tensor? I’m using PyTorch 1.7 and am trying to use a sparse tensor in place of a dense tensor I’ve been using (which is extremely sparse). The code fails when it reaches reshape, because reshape is not implemented for sparse tensors. Are there any plans to implement this?

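import torch


def sparse_reshape(sparse_tensor, size):
    # indices() and values() require a coalesced tensor; coalescing also sorts the
    # entries in row-major order so the values line up with the new indices below.
    sparse_tensor = sparse_tensor.coalesce()
    # Boolean mask that is True exactly where the sparse tensor stores an entry.
    mask = torch.sparse_coo_tensor(indices=sparse_tensor.indices(),
                                   values=torch.ones(sparse_tensor._nnz(),
                                                     device=sparse_tensor.device,
                                                     dtype=torch.bool),
                                   size=sparse_tensor.size())
    # Reshape the mask as a dense tensor (note: this allocates numel() booleans).
    mask = mask.to_dense().reshape(size)

    # nonzero() lists positions in row-major order, matching the coalesced values.
    # Using mask.shape (instead of size) also resolves a -1 in the requested size.
    return torch.sparse_coo_tensor(indices=mask.nonzero().t(),
                                   values=sparse_tensor.values(),
                                   size=mask.shape).coalesce()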
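The function above still builds a dense mask, so its memory use grows with the full number of elements rather than the number of nonzeros. A possible alternative, sketched below under stated assumptions, never densifies: it converts each COO index column to a flat row-major offset and then unravels that offset into the new shape, which is the same element ordering `torch.reshape` uses for dense tensors. The name `sparse_reshape_coo` is hypothetical, and the sketch assumes a plain COO tensor with no dense (hybrid) dimensions and at most one `-1` in the target shape.

```python
import math

import torch


def sparse_reshape_coo(sparse_tensor, shape):
    """Hypothetical helper: reshape a COO tensor using only index arithmetic.

    Assumes no dense (hybrid) dimensions and row-major semantics like torch.reshape.
    """
    # indices() needs a coalesced tensor; coalescing also sums duplicate entries.
    t = sparse_tensor.coalesce()
    old_shape = list(t.shape)
    numel = math.prod(old_shape)

    # Resolve a single -1 in the requested shape, as torch.reshape does.
    new_shape = list(shape)
    if -1 in new_shape:
        pos = new_shape.index(-1)
        known = math.prod(d for i, d in enumerate(new_shape) if i != pos)
        new_shape[pos] = numel // known
    if math.prod(new_shape) != numel:
        raise ValueError(f"cannot reshape {old_shape} into {list(shape)}")

    indices = t.indices()  # shape (ndim, nnz), dtype int64
    # Flatten: Horner-style accumulation gives the row-major linear offset.
    flat = torch.zeros(indices.shape[1], dtype=torch.int64, device=indices.device)
    for dim, extent in enumerate(old_shape):
        flat = flat * extent + indices[dim]

    # Unravel the offsets into the new shape, starting from the last dimension.
    # Offsets are non-negative, so // behaves the same as floor division here.
    new_rows = []
    for extent in reversed(new_shape):
        new_rows.append(flat % extent)
        flat = flat // extent
    new_indices = torch.stack(new_rows[::-1])

    # Each index column is remapped independently, so values() stays aligned.
    return torch.sparse_coo_tensor(new_indices, t.values(), new_shape).coalesce()
```

For example, `sparse_reshape_coo(x.to_sparse(), (4, -1))` on a 3×4 tensor `x` should give the same result as `x.reshape(4, 3).to_sparse()`. The cost is a few vectorized operations over the `nnz` index columns, at the price of handling only the plain COO layout.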