Mean of non-zero values of a sparse tensor

Supposing I have a sparse COO tensor of shape (6, 100, 100, 100, C), where C is the number of channels, I'm looking for a way to do sparse pooling across dimensions (1, 2, 3), i.e. average only the non-zero elements.

Are there any hacky ways of doing this?
I see torch.sparse.sum, but I would also need to count the number of non-zero elements to get a true average, and I'm not sure how to do that without going to a dense tensor first.
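
For reference, on a dense tensor the operation I'm after would look something like this (the small random x below is just a stand-in for the real (6, 100, 100, 100, C) data):

import torch

# dense stand-in for the real (6, 100, 100, 100, C) input
x = torch.randn(6, 4, 4, 4, 8) * torch.randint(2, (6, 4, 4, 4, 8))

nonzero_sum = x.sum(dim=(1, 2, 3))                        # (6, 8) per-channel sums
nonzero_cnt = (x != 0).sum(dim=(1, 2, 3)).clamp(min=1)    # (6, 8) per-channel non-zero counts
nonzero_mean = nonzero_sum / nonzero_cnt                  # averages only the non-zero entries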

Hi James!

A hack, you say? Why, of course!

>>> import torch
>>> print (torch.__version__)
2.4.0
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> t = torch.randn (2, 3, 3) * torch.randint (2, (2, 3, 3))
>>> t
tensor([[[-0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.8567,  1.9776],
         [-0.0000,  1.3667,  0.0000]],

        [[-0.3869,  1.6579, -1.3085],
         [ 0.9962,  0.9391,  0.0000],
         [ 0.0000, -0.0776, -0.0000]]])
>>>
>>> t.sum ((1, 2)) / (t != 0).sum ((1, 2))   # compute "non-zero" means from dense tensor
tensor([1.7337, 0.3034])
>>>
>>> t = t.to_sparse()                        # convert to sparse tensor
>>> t
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1],
                       [1, 1, 2, 0, 0, 0, 1, 1, 2],
                       [1, 2, 1, 0, 1, 2, 0, 1, 1]]),
       values=tensor([ 1.8567,  1.9776,  1.3667, -0.3869,  1.6579, -1.3085,
                       0.9962,  0.9391, -0.0776]),
       size=(2, 3, 3), nnz=9, layout=torch.sparse_coo)
>>>
>>> tnzm = t.sum ((1, 2))                    # sums of non-zero elements
>>> tcln = t.clone()
>>> tcln.values().fill_ (1.0)                # change non-zero elements of clone to 1
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
>>> tcnt = tcln.sum ((1, 2))                 # counts of non-zero elements
>>>
>>> tnzm
tensor(indices=tensor([[0, 1]]),
       values=tensor([5.2010, 1.8201]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
>>> tcnt
tensor(indices=tensor([[0, 1]]),
       values=tensor([3., 6.]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
>>>
>>> tval = tnzm.values()
>>> tval /= tcnt.values()                    # convert sums to means
>>>
>>> tnzm                                     # non-zero means computed with sparse tensors
tensor(indices=tensor([[0, 1]]),
       values=tensor([1.7337, 0.3034]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
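
Packaged up for the (6, 100, 100, 100, C) shape in your question, the same trick might look something like this (just a sketch, assuming a fully sparse COO layout with no explicit zeros stored, so every stored element counts as non-zero):

import torch

def sparse_nonzero_mean(t: torch.Tensor) -> torch.Tensor:
    # t: sparse COO tensor of shape (B, D1, D2, D3, C)
    # returns a sparse (B, C) tensor with the per-batch, per-channel mean
    # of the stored (non-zero) elements over the spatial dims (1, 2, 3)
    t = t.coalesce()
    sums = torch.sparse.sum(t, dim=(1, 2, 3)).coalesce()        # sums of non-zero elements
    ones = torch.sparse_coo_tensor(t.indices(),
                                   torch.ones_like(t.values()),
                                   t.shape)
    counts = torch.sparse.sum(ones, dim=(1, 2, 3)).coalesce()   # counts of non-zero elements
    # both reductions run over the same coalesced index set, so their
    # entries line up index-for-index
    return torch.sparse_coo_tensor(sums.indices(),
                                   sums.values() / counts.values(),
                                   sums.shape)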

Best.

K. Frank

Thanks, Frank. Using spconv I have an alternative (not great) solution:

from typing import Optional, Tuple, List

import torch, torch.nn as nn
import spconv.pytorch as spconv

def sp_conv_t_to_coo_tensor(x: spconv.SparseConvTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Convert an spconv SparseConvTensor into a torch sparse COO tensor, plus the
    # per-batch counts of occupied coordinates along each spatial dim.
    non_zero_counts = get_num_nonzero_dim(x.indices, x.batch_size, dims=[1, 2, 3])
    indices = x.indices.permute(1, 0)   # spconv stores indices as (N, 4); COO wants (4, N)
    features = x.features
    # note: the size is inferred from the max index, not taken from x.spatial_shape
    coo_tensor = torch.sparse_coo_tensor(indices=indices, values=features)
    return coo_tensor, non_zero_counts


def get_num_nonzero_dim(indices, batch_size, dims: List[int]):
    # For each batch element, count the number of distinct occupied coordinates
    # along each requested spatial dim (column 0 of `indices` is the batch id).
    non_zero_counts = torch.zeros((batch_size, len(dims)), dtype=torch.float, device=indices.device)
    indices_batch = indices[:, 0]
    for batch_id in range(batch_size):
        mask = indices_batch == batch_id
        indices_batch_id = indices[mask]
        for dim in dims:
            unique = torch.unique(indices_batch_id[:, dim])
            non_zero_counts[batch_id][dim - 1] = unique.shape[0]
    return non_zero_counts


def mean_coo_tensor_dim(x: torch.Tensor, num_nonzero: torch.Tensor, dim: int):
    # Sum the sparse tensor over `dim`, then divide each entry by the per-batch
    # count for that dim to turn the sums into means.
    summed = torch.sparse.sum(x, dim=dim)
    bs = x.shape[0]
    num_points = summed._values().shape[0]
    indices = summed._indices().permute(1, 0)
    values_c = summed._values().shape[-1]
    div_vals = torch.zeros((num_points,), device=x.device)
    for batch_id in range(bs):
        mask = indices[:, 0] == batch_id
        div_vals[mask] = num_nonzero[batch_id][dim - 1]
    div_vals = torch.unsqueeze(div_vals, dim=-1)
    div_vals = div_vals.repeat((1, values_c))
    summed._values().div_(div_vals)   # in-place division on the raw values
    return summed

def pcs_to_sparse_tensor(pcs: torch.Tensor, grid_size: float,
                         device_id: Optional[int], pad: int = 0) -> spconv.SparseConvTensor:
    # Voxelize a batch of point clouds of shape (batch, num_points, C) into an
    # spconv SparseConvTensor, using the first three feature columns as xyz.
    batch_size, num_points, C = pcs.shape
    flattened_feats = pcs.reshape(-1, C)
    pcs_f = flattened_feats[:, :3]
    pcs_device = pcs.device
    grid_coord = torch.div(pcs_f - pcs_f.min(0)[0], grid_size, rounding_mode="trunc").int()
    sparse_shape = torch.add(torch.max(grid_coord, dim=0).values, pad).tolist()
    repeat_vals = torch.tensor([num_points for _ in range(batch_size)]).to(pcs_device)
    batch_vals = torch.arange(0, batch_size, step=1).to(pcs_device)
    batch_idx = torch.repeat_interleave(batch_vals, repeat_vals)
    indices = torch.cat([batch_idx.unsqueeze(-1).int(), grid_coord], dim=1).contiguous()
    feats = flattened_feats
    if device_id is not None:
        if pcs_device != device_id:
            feats = feats.to(device_id)
            indices = indices.to(device_id)
    sp_tensor = spconv.SparseConvTensor(
        features=feats,
        indices=indices,
        spatial_shape=sparse_shape,
        batch_size=batch_size
    )
    return sp_tensor

if __name__ == "__main__":
    rand_sparse_tensor = pcs_to_sparse_tensor(pcs=torch.rand(6, 512, 3),
                                              grid_size=0.01, device_id=2, pad=0)
    a, b = sp_conv_t_to_coo_tensor(rand_sparse_tensor)
    mean = mean_coo_tensor_dim(a, b, 1)

But I think this breaks the gradient, so it's not too useful.
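
If the in-place div_ on ._values() is what breaks autograd, one option might be to do the division out-of-place and rebuild the sparse tensor from the divided values. A rough, untested sketch (mean_coo_tensor_dim_autograd is just a hypothetical replacement for mean_coo_tensor_dim above; I haven't verified that gradients actually flow end-to-end):

import torch

def mean_coo_tensor_dim_autograd(x: torch.Tensor, num_nonzero: torch.Tensor, dim: int):
    # same idea as mean_coo_tensor_dim, but the division is out-of-place so it
    # isn't hidden from autograd behind a ._values() mutation
    summed = torch.sparse.sum(x, dim=dim).coalesce()
    indices = summed.indices()
    # look up the per-batch divisor for each stored entry (indices[0] is the batch id)
    div_vals = num_nonzero[indices[0], dim - 1].unsqueeze(-1)
    return torch.sparse_coo_tensor(indices, summed.values() / div_vals, summed.shape)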