Is there an efficient way to reshape a sparse tensor? I’m using PyTorch 1.7 and am trying to use a sparse tensor in place of a dense tensor that is extremely sparse. The code fails at the reshape call because reshape is not implemented for sparse tensors. Are there any plans to implement this?
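import torch

def sparse_reshape(sparse_tensor, size):
    # indices()/values() require a coalesced tensor; coalescing also sorts the
    # entries in row-major order, which the value reuse below relies on
    sparse_tensor = sparse_tensor.coalesce()
    # Boolean mask marking every stored position of the input
    mask = torch.sparse_coo_tensor(indices=sparse_tensor.indices(),
                                   values=torch.ones(sparse_tensor._nnz(),
                                                     device=sparse_tensor.device,
                                                     dtype=torch.bool),
                                   size=sparse_tensor.size())
    # Reshape the (dense) mask; the row-major order of the True entries is preserved
    mask = mask.to_dense().reshape(size)
    sparse_tensor = torch.sparse_coo_tensor(indices=mask.to_sparse().indices(),
                                            values=sparse_tensor.values(),
                                            size=mask.shape)  # mask.shape resolves any -1 in size
    return sparse_tensor.coalesce()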
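The function above still materializes a dense boolean mask with `size.numel()` elements, which can defeat the purpose for a large, very sparse tensor. A minimal sketch of an alternative that only touches the stored indices: flatten each index to its row-major linear offset, then unravel that offset into the new shape. The helper name `sparse_reshape_no_dense` is hypothetical, and it assumes `size` contains no `-1` entries.

```python
import torch


def sparse_reshape_no_dense(sparse_tensor, size):
    """Reshape a sparse COO tensor by remapping its indices (row-major order).

    Hypothetical helper; assumes `size` has no -1 entries.
    """
    t = sparse_tensor.coalesce()  # indices()/values() need a coalesced tensor
    old_shape = t.shape
    new_shape = torch.Size(size)
    if old_shape.numel() != new_shape.numel():
        raise ValueError(f"cannot reshape {tuple(old_shape)} into {tuple(new_shape)}")

    indices = t.indices()  # shape (old_ndim, nnz), dtype int64

    # Flatten each N-d index to the linear offset a contiguous dense tensor would use.
    linear = torch.zeros(indices.size(1), dtype=torch.int64, device=indices.device)
    for dim, extent in enumerate(old_shape):
        linear = linear * extent + indices[dim]

    # Unravel the linear offsets into the new shape, last dimension first.
    new_indices = []
    for extent in reversed(new_shape):
        new_indices.append(linear % extent)
        linear = linear // extent
    new_indices = torch.stack(new_indices[::-1])  # shape (new_ndim, nnz)

    # Row-major order is preserved, so this coalesce only sets the coalesced flag cheaply.
    return torch.sparse_coo_tensor(new_indices, t.values(), new_shape).coalesce()


# Usage: nonzeros at (0, 3), (1, 0), (2, 1) in a 3x4 tensor, reshaped to 2x6
i = torch.tensor([[0, 1, 2], [3, 0, 1]])
v = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(i, v, (3, 4))
r = sparse_reshape_no_dense(s, (2, 6))
print(torch.equal(r.to_dense(), s.to_dense().reshape(2, 6)))
```

Memory here scales with the number of stored elements rather than with `numel()`. Unlike `torch.reshape`, this sketch does not infer a `-1` dimension, and the linear offsets are computed in int64, so extremely large shapes could overflow.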