Hi James!
A hack, you say? Why, of course!

The trick: clone the sparse tensor and overwrite the clone's values with ones, so that summing the clone gives per-batch counts of the non-zero elements; dividing the per-batch sums by those counts (in place, on the .values() of the sparse sums) then gives the non-zero means.
>>> import torch
>>> print (torch.__version__)
2.4.0
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> t = torch.randn (2, 3, 3) * torch.randint (2, (2, 3, 3))
>>> t
tensor([[[-0.0000, -0.0000,  0.0000],
         [ 0.0000,  1.8567,  1.9776],
         [-0.0000,  1.3667,  0.0000]],

        [[-0.3869,  1.6579, -1.3085],
         [ 0.9962,  0.9391,  0.0000],
         [ 0.0000, -0.0776, -0.0000]]])
>>>
>>> t.sum ((1, 2)) / (t != 0).sum ((1, 2)) # compute "non-zero" means from dense tensor
tensor([1.7337, 0.3034])
>>>
>>> t = t.to_sparse() # convert to sparse tensor
>>> t
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1],
                       [1, 1, 2, 0, 0, 0, 1, 1, 2],
                       [1, 2, 1, 0, 1, 2, 0, 1, 1]]),
       values=tensor([ 1.8567,  1.9776,  1.3667, -0.3869,  1.6579, -1.3085,
                       0.9962,  0.9391, -0.0776]),
       size=(2, 3, 3), nnz=9, layout=torch.sparse_coo)
>>>
>>> tnzm = t.sum ((1, 2)) # sums of non-zero elements
>>> tcln = t.clone()
>>> tcln.values().fill_ (1.0) # change non-zero elements of clone to 1
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
>>> tcnt = tcln.sum ((1, 2)) # counts of non-zero elements
>>>
>>> tnzm
tensor(indices=tensor([[0, 1]]),
       values=tensor([5.2010, 1.8201]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
>>> tcnt
tensor(indices=tensor([[0, 1]]),
       values=tensor([3., 6.]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
>>>
>>> tval = tnzm.values()
>>> tval /= tcnt.values() # convert sums to means
>>>
>>> tnzm # non-zero means computed with sparse tensors
tensor(indices=tensor([[0, 1]]),
       values=tensor([1.7337, 0.3034]),
       size=(2,), nnz=2, layout=torch.sparse_coo)
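
If you want this packaged up as a helper, here is a minimal sketch built on the same clone-and-fill trick. The name nonzero_mean_per_batch is just something I made up, and it assumes that dim 0 is your "batch" dimension and that the sparse tensor stores no explicit zeros:

import torch

def nonzero_mean_per_batch (sp: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: mean of the stored (non-zero) elements of each
    # dim-0 slice of a sparse COO tensor, returned as a dense 1-d tensor
    sp = sp.coalesce()                            # values() needs a coalesced tensor
    dims = tuple (range (1, sp.dim()))            # reduce over all dims except dim 0
    sums = sp.sum (dims)                          # sparse per-batch sums
    ones = sp.clone()
    ones.values().fill_ (1.0)                     # turn every stored element into 1
    counts = ones.sum (dims)                      # sparse per-batch counts
    return sums.to_dense() / counts.to_dense()    # nan for a slice with no non-zeros

With the t.to_sparse() tensor from the session above, nonzero_mean_per_batch (t) should give back tensor([1.7337, 0.3034]), this time as a plain dense tensor.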
Best.
K. Frank