Mean of non-zero values of a sparse tensor

Suppose I have a sparse COO tensor of shape (6, 100, 100, 100, C), with C channels. I want to do sparse pooling across dimensions (1, 2, 3), i.e. average only the non-zero elements.

Are there any hacky ways of doing this?
I see `torch.sparse.sum`, but to get a true average I would also need the number of non-zero elements, and I'm not sure how to count them without converting to a dense tensor first.

Hi James!

A hack, you say? Why, of course!

``````
>>> import torch
>>> print (torch.__version__)
2.4.0
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> t = torch.randn (2, 3, 3) * torch.randint (2, (2, 3, 3))
>>> t
tensor([[[-0.0000, -0.0000,  0.0000],
[ 0.0000,  1.8567,  1.9776],
[-0.0000,  1.3667,  0.0000]],

[[-0.3869,  1.6579, -1.3085],
[ 0.9962,  0.9391,  0.0000],
[ 0.0000, -0.0776, -0.0000]]])
>>>
>>> t.sum ((1, 2)) / (t != 0).sum ((1, 2))   # compute "non-zero" means from dense tensor
tensor([1.7337, 0.3034])
>>>
>>> t = t.to_sparse()                        # convert to sparse tensor
>>> t
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 2, 0, 0, 0, 1, 1, 2],
[1, 2, 1, 0, 1, 2, 0, 1, 1]]),
values=tensor([ 1.8567,  1.9776,  1.3667, -0.3869,  1.6579, -1.3085,
0.9962,  0.9391, -0.0776]),
size=(2, 3, 3), nnz=9, layout=torch.sparse_coo)
>>>
>>> tnzm = t.sum ((1, 2))                    # sums of non-zero elements
>>> tcln = t.clone()
>>> tcln.values().fill_ (1.0)                # change non-zero elements of clone to 1
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
>>> tcnt = tcln.sum ((1, 2))                 # counts of non-zero elements
>>>
>>> tnzm
tensor(indices=tensor([[0, 1]]),
values=tensor([5.2010, 1.8201]),
size=(2,), nnz=2, layout=torch.sparse_coo)
>>> tcnt
tensor(indices=tensor([[0, 1]]),
values=tensor([3., 6.]),
size=(2,), nnz=2, layout=torch.sparse_coo)
>>>
>>> tval = tnzm.values()
>>> tval /= tcnt.values()                    # convert sums to means
>>>
>>> tnzm                                     # non-zero means computed with sparse tensors
tensor(indices=tensor([[0, 1]]),
values=tensor([1.7337, 0.3034]),
size=(2,), nnz=2, layout=torch.sparse_coo)
``````

Best.

K. Frank

1 Like

Thanks, Frank. Using spconv, I have an alternative (not great) solution:

``````
from typing import Optional, Tuple, List

import torch, torch.nn as nn
import spconv.pytorch as spconv

def sp_conv_t_to_coo_tensor(x: spconv.SparseConvTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert an spconv ``SparseConvTensor`` into a torch sparse COO tensor.

    Assumes ``x.indices`` is an ``(N, 1 + D)`` integer tensor whose column 0
    is the batch index (the layout ``pcs_to_sparse_tensor`` builds), and that
    ``x.features`` holds one feature row per index row.

    Args:
        x: the spconv sparse tensor to convert.

    Returns:
        ``(coo_tensor, non_zero_counts)`` where

        - ``coo_tensor`` is a hybrid sparse COO tensor with sparse dims
          ``(batch, *spatial)`` and the trailing feature dims dense. Its size
          comes from ``x.batch_size`` / ``x.spatial_shape``.
        - ``non_zero_counts`` is ``get_num_nonzero_dim`` evaluated on dims
          ``[1, 2, 3]``.
    """
    non_zero_counts = get_num_nonzero_dim(x.indices, x.batch_size, dims=[1, 2, 3])
    # COO expects indices as (num_sparse_dims, nnz); spconv stores (nnz, ...).
    indices = x.indices.permute(1, 0)
    features = x.features
    # Give the size explicitly. Inferring it from max(index) + 1 would shrink
    # the tensor whenever trailing batch elements or grid cells are empty,
    # misaligning it with non_zero_counts (which always has batch_size rows).
    size = (x.batch_size, *(int(s) for s in x.spatial_shape), *features.shape[1:])
    coo_tensor = torch.sparse_coo_tensor(indices=indices, values=features, size=size)
    return coo_tensor, non_zero_counts

def get_num_nonzero_dim(indices, batch_size, dims: List[int]):
    """Count, per batch element, the distinct coordinates used along each dim.

    Args:
        indices: ``(N, K)`` integer tensor of point coordinates; column 0 is
            the batch index.
        batch_size: number of batch elements, i.e. rows of the result.
        dims: index columns to inspect.

    Returns:
        ``(batch_size, len(dims))`` float tensor. Entry ``[b, j]`` is the
        number of distinct values in column ``dims[j]`` among the rows whose
        batch index is ``b``; batch elements with no rows get 0.

    NOTE(review): this counts distinct coordinates along each dim, not the
    number of stored points, so it is not the divisor for a per-cell mean.
    """
    non_zero_counts = torch.zeros((batch_size, len(dims)), dtype=torch.float, device=indices.device)
    batch_column = indices[:, 0]
    for batch_id in range(batch_size):
        # Keep only the rows that belong to this batch element.
        batch_rows = indices[batch_column == batch_id]
        for col, dim in enumerate(dims):
            # Output column follows the position in `dims`; for the usual
            # dims=[1, 2, 3] this is the same as `dim - 1`.
            non_zero_counts[batch_id, col] = torch.unique(batch_rows[:, dim]).numel()
    return non_zero_counts

def mean_coo_tensor_dim(x: torch.Tensor, num_nonzero: torch.Tensor, dim: int):
    """Average only the stored (non-zero) entries of a sparse COO tensor along ``dim``.

    Each output entry is the sum of the stored entries that collapse onto it,
    divided by how many such entries there are. Implicit zeros never enter
    the denominator.

    Args:
        x: sparse COO tensor with floating-point values. It may be hybrid,
            e.g. with a dense trailing channel dim. Reducing ``dim`` must
            leave at least one sparse dim, so the result is still sparse.
        num_nonzero: not used; accepted so existing callers keep working.
            The per-output counts are computed exactly from ``x`` instead.
        dim: the sparse dimension to reduce.

    Returns:
        Sparse COO tensor holding the non-zero means. Positions with no
        stored entries along ``dim`` are absent (implicitly zero).
    """
    # Merge duplicate indices so each stored position is counted once.
    x = x.coalesce()
    sums = torch.sparse.sum(x, dim=dim).coalesce()
    # Same sparsity pattern with every stored value replaced by 1: summing it
    # along `dim` gives, per output position, the number of stored entries.
    ones = torch.sparse_coo_tensor(x.indices(), torch.ones_like(x.values()), x.shape)
    counts = torch.sparse.sum(ones, dim=dim).coalesce()
    # Both reductions start from identical indices, so their coalesced
    # (sorted) outputs share index order and divide value-by-value.
    sums.values().div_(counts.values())
    return sums

def pcs_to_sparse_tensor(pcs: torch.Tensor, grid_size: float,
                         device_id: Optional[int], pad: int = 0) -> spconv.SparseConvTensor:
    """Voxelize a batch of point clouds into an spconv ``SparseConvTensor``.

    Args:
        pcs: ``(batch_size, num_points, C)`` tensor. The first three channels
            are used as point coordinates; all C channels become features.
        grid_size: voxel edge length, in the same units as the coordinates.
        device_id: if not None, features and indices are moved to this device.
        pad: extra voxels added to every axis of ``spatial_shape``.

    Returns:
        Sparse tensor with one row per input point. Its ``indices`` are int32
        rows of ``(batch, g0, g1, g2)``, and ``spatial_shape`` is the largest
        occupied cell + 1 + ``pad`` along each axis.

    NOTE(review): points falling into the same voxel produce duplicate index
    rows (nothing here deduplicates them); confirm spconv tolerates this.
    """
    batch_size, num_points, C = pcs.shape
    flattened_feats = pcs.reshape(-1, C)
    pcs_f = flattened_feats[:, :3]
    # Shift so the smallest coordinate (over the whole batch) lands in cell 0;
    # coordinates are then non-negative, so truncation equals flooring.
    grid_coord = torch.div(pcs_f - pcs_f.min(0)[0], grid_size, rounding_mode="trunc").int()
    # After the reshape, point p of batch element b sits at row b * num_points + p.
    batch_idx = torch.arange(batch_size, device=pcs.device).repeat_interleave(num_points)
    indices = torch.cat([batch_idx.unsqueeze(-1).int(), grid_coord], dim=1).contiguous()
    # Spatial extent: one past the largest occupied cell on each axis, plus padding.
    sparse_shape = (grid_coord.max(0)[0] + 1 + pad).tolist()
    feats = flattened_feats
    if device_id is not None:
        # `.to` is a no-op when the tensor already lives on the target device.
        feats = feats.to(device_id)
        indices = indices.to(device_id)
    return spconv.SparseConvTensor(
        features=feats,
        indices=indices,
        spatial_shape=sparse_shape,
        batch_size=batch_size,
    )

if __name__ == "__main__":
rand_sparse_tensor = pcs_to_sparse_tensor(pcs=torch.rand(6, 512, 3),