`nodes.Batcher` on multiple outputs

I am using `nodes.Batcher` on a node that produces multiple outputs per item (each item is a tuple):

```python
from torchdata.nodes import Batcher, IterableWrapper, Loader, Mapper

node = IterableWrapper(range(16))
node = Mapper(node, map_fn=lambda x: (x, x**2))
node = Batcher(node, batch_size=4)
loader = Loader(node)
print(list(loader))
```

This produces batches of tuples (the first batch is `[(0, 0), (1, 1), (2, 4), (3, 9)]`) instead of a tuple of batches such as `([0, 1, 2, 3], [0, 1, 4, 9])`. What am I doing wrong?

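I was missing a call to `default_collate` after batching. `Batcher` only groups consecutive items into a list, so turning that list of tuples into one batch per field is a separate collation step.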