Hello, I’m trying to build a custom class/reformulation of scatter_reduce. The intent is that, rather than storing a new index tensor for every forward call to scatter_reduce, one creates an instance of this class to store the index once. Then, based on the shape of the data, the index can be expanded to match the input. Any subsequent calls to this operator during a forward pass do not need to store a new copy of the index tensor.
Why I want this
I’m working on a problem for LArTPC readouts (see arXiv entries 1802.08709, 2007.12743 for some info)
I want to go from 3x 2D projections (views) of the 3D data in slices of time, to a single “overlapping view” (i.e. where a trio of wires overlap) again in slices of time, and then back again.
That “back again” point is where the scattering comes in. One can build a mapping from the readout index, create the overlap view, perform some operation, and then scatter back to the readout index (since a single readout will engage in multiple crossings for a single time slice, we need to aggregate from the operated-upon overlap view).
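To make the round trip concrete, here is a minimal sketch on toy data. The sizes and the `chan_of_overlap` mapping are made up for illustration, the layout (time, channel) is an assumption, and the multiplication stands in for whatever operation is applied on the overlap view.

```python
import torch

# Hypothetical toy sizes: 4 readout channels, 6 overlaps, 2 time slices.
n_readout, n_time = 4, 2
# chan_of_overlap[k] is the readout channel that overlap k touches (invented mapping).
chan_of_overlap = torch.tensor([0, 1, 1, 2, 3, 3])

readout = torch.arange(n_time * n_readout, dtype=torch.float32).reshape(n_time, n_readout)

# Readout -> overlap view: each overlap reads the value of its channel.
overlap_view = readout[:, chan_of_overlap]            # shape (n_time, 6)

# Stand-in for "some operation" on the overlap view.
processed = overlap_view * 2.0

# Overlap -> readout: sum every overlap that shares a channel.
index = chan_of_overlap.expand(n_time, -1)            # broadcast view of the 1D index
back = torch.zeros(n_time, n_readout).scatter_reduce(-1, index, processed, reduce="sum")
print(back)
```

Channels 1 and 3 each take part in two overlaps here, so their entries in `back` are sums of two processed values, which is the aggregation step described above.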
For one style of detector, we have 2560 readout elements that form 1.3M overlaps, and we operate on the order of 100s of time slices for a single batch. The size of this is quite large, and the traditional scattering implementation leaves little flexibility for, e.g., performing the scattering multiple times on chunks of time slices. So I’ve started to develop this “memorized” scattering.
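As a rough sense of scale, the arithmetic below estimates the size of one 1D index against an index that has been fully materialized for every time slice. The int64 dtype and the choice of exactly 100 time slices are assumptions; an index expanded as a broadcast view would not occupy the larger amount.

```python
# Back-of-the-envelope index sizes (assumptions flagged inline).
n_overlaps = 1_300_000      # from the post
n_time_slices = 100         # assumed: "on the order of 100s"
bytes_per_index = 8         # assumed int64 indices

one_slice_mib = n_overlaps * bytes_per_index / 2**20
full_batch_gib = n_overlaps * n_time_slices * bytes_per_index / 2**30

print(f"1D index for one plane: {one_slice_mib:.1f} MiB")
print(f"same index materialized per time slice: {full_batch_gib:.2f} GiB")
```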
What I’ve done
Note: Admittedly, I used Gemini to explore this idea. You can see the chat here. There are some questions about packaging with uv as well, so you can ignore those. It eventually suggested making a C++ class inheriting from torch::CustomClassHolder (to hold the index and to perform the op) and, in order to work with autograd, creating a struct inheriting from torch::autograd::Node to define the backward pass and gradient edges. As seen in that Gemini chat, this isn’t the usual way to write a C++ extension, so if this is entirely the wrong approach, that would be useful feedback to hear.
Implementation
As of writing, the main forward implementation is the following. Note, I only start with the “sum” reduction, though “max/amax” might be useful for my use-case. Apologies for the sloppy comments/to-dos
torch::Tensor forward(torch::Tensor src, std::vector<int64_t> results_shape) {
    // Zero-copy expansion in forward.
    // Shallow handle to the stored index (shares storage; no data is copied).
    auto expanded_index = index_1d;
    // std::cout << "Expanded index is on " << expanded_index.device() << std::endl;
    // TODO -- make this smarter and for a given dimension.
    // Prepend one singleton dim per leading dim of src (all but the last dim).
    // Signed counter: an unsigned size()-1 would underflow for a 0-dim src.
    for (int64_t i = 0; i < src.dim() - 1; ++i) {
        expanded_index = expanded_index.unsqueeze(0);
    }
    expanded_index = expanded_index.expand_as(src);
    auto result = at::zeros(results_shape, src.options());
    {
        // Run the scatter untracked; the backward edge is attached manually below.
        at::NoGradGuard no_grad;
        result.scatter_reduce_(-1, expanded_index, src, "sum", false);
    }
    if (GradMode::is_enabled() && src.requires_grad()) {
        auto grad_fn = std::make_shared<MyScatterNode>();
        grad_fn->persistent_index = index_1d;
        grad_fn->input_shape = src.sizes().vec();
        grad_fn->set_next_edges(collect_next_edges(src));
        create_gradient_edge(result, grad_fn);
    }
    return result;
}
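For comparison, the same forward can be sketched at the Python level with torch.autograd.Function, where the backward of a scatter-sum is a gather with the same index. This is a minimal sketch, not the repo's implementation: the names `ScatterSumFn` and `MemorizedScatterSum` are hypothetical, and it only covers the "sum" reduction along the last dimension.

```python
import torch


class ScatterSumFn(torch.autograd.Function):
    """Scatter-sum along the last dim; backward gathers with the same index."""

    @staticmethod
    def forward(ctx, src, index_1d, out_size):
        index = index_1d.expand_as(src)              # broadcast view of the 1D index
        out = src.new_zeros(*src.shape[:-1], out_size)
        out.scatter_reduce_(-1, index, src, "sum", include_self=True)
        ctx.save_for_backward(index_1d)              # keeps a reference to the 1D index only
        ctx.src_shape = src.shape
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (index_1d,) = ctx.saved_tensors
        index = index_1d.expand(ctx.src_shape)
        # out[..., j] sums every src[..., k] with index[k] == j,
        # so the gradient of src[..., k] is grad_out[..., index[k]].
        return grad_out.gather(-1, index), None, None


class MemorizedScatterSum(torch.nn.Module):
    """Hypothetical wrapper that stores the index once and reuses it per call."""

    def __init__(self, index_1d):
        super().__init__()
        self.register_buffer("index_1d", index_1d)

    def forward(self, src, out_size):
        return ScatterSumFn.apply(src, self.index_1d, out_size)


# Usage on toy data: 5 source positions scattered into 4 output slots.
op = MemorizedScatterSum(torch.tensor([0, 2, 2, 1, 0]))
src = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
out = op(src, 4)
out.sum().backward()
```

Registering the index as a buffer means it follows the module across `.to(device)` calls. Whether this Python version is competitive with the C++ class for the sizes in this post is not something this sketch measures.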
Preliminary Testing
I’ve done some testing in an actual training setup, and I’m a bit disappointed with its speed. One suspicion I have is it’s because I’m not making real use of the atomic operations. Here’s why:
We have 3 readout planes that overlap to form the ‘overlap-view’. I want to distribute information from the overlap view to these three planes. I store them in a ‘monolithic’ tensor of shape (2560, N time slices). The overlap-view indices are of shape (3, N overlaps): each plane gets an indices tensor and thus its own operation. I end up doing this:
op1 = my_custom_op(indices1)
op2 = my_custom_op(indices2)
op3 = my_custom_op(indices3)
output = torch.zeros(…)
output += op1(overlap_view[0], output.shape)
output += op2(overlap_view[1], output.shape)
output += op3(overlap_view[2], output.shape)
Rather than doing, e.g.:
op1(overlap_view[0], output)
op2(overlap_view[1], output)
op3(overlap_view[2], output)
The reason: I’m unsure of what to do with the gradient edges in the case of in-place operations.
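One possible point of comparison at the Python level is the built-in `Tensor.index_add_`, which takes a 1D index directly and is tracked by autograd even though it writes in place. The sketch below is an illustration under assumptions, not a drop-in replacement for the custom op: the sizes are toy values, the layout is taken to be (time, channel), and the indices are assumed to already point at channel positions in the monolithic output.

```python
import torch

torch.manual_seed(0)
n_readout, n_time, n_overlaps = 5, 2, 7     # toy sizes, assumed (time, channel) layout

# One 1D index per plane; values are channel positions in the shared output (invented here).
indices = torch.randint(0, n_readout, (3, n_overlaps))
overlap_view = torch.randn(3, n_time, n_overlaps, requires_grad=True)

output = torch.zeros(n_time, n_readout)
for plane in range(3):
    # In-place accumulation along the last dim; autograd records each call.
    output.index_add_(-1, indices[plane], overlap_view[plane])

# Every source element contributes to exactly one output slot,
# so the gradient of output.sum() with respect to overlap_view is all ones.
output.sum().backward()
```

Here autograd sets up the gradient edges itself, so it may serve as a reference for checking the gradients of a hand-written in-place version. Its speed relative to the custom op for 1.3M overlaps is untested here.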
I did testing with the profiling script in the repo using the actual in-place operations (the second example above) and saw a decent speedup.
Scatter then add:
6 ops, overall time (including profiler warmup ~20ms) ~ 165ms
In-place scatter:
6 ops, overall time (including profiler warmup ~20ms) ~ 135ms
So about a 30ms reduction (~20% speedup).
It seems like actually using the in-place ops is useful, but I’m confused about how to set up the gradient edges.
Building + testing
If you’d like to build and run the test above:
- Git repo: GitHub - calcuttj/MemorizedScatterOp
- building:
uv run --with 'torch==2.8.0' --with setuptools python setup.py bdist_wheel
- running test:
uv run --with 'torch==2.8.0' python scripts/profile_scatter.py --indices "cells_chans.pt:cells_chans_f0:0" --library build/lib.linux-x86_64-cpython-311/my_custom_ops_lib.cpython-311-x86_64-linux-gnu.so --device 'cuda'
Wrapping up
So that all being said, does this seem like a useful path to go down? I can imagine there might be other approaches to this problem that might be better, whether that’s formatting the data or problem differently, etc. If this is a decent path, I’d love some help understanding any useful optimizations that I might be missing + implementation help with the gradient edges if it comes to it.
Thanks to anyone who reads through this, I know I’ve dumped a lot here!