I am using nodes.Batcher on a node with multiple outputs:
```python
from torchdata.nodes import Batcher, IterableWrapper, Loader, Mapper

node = IterableWrapper(range(16))
node = Mapper(node, map_fn=lambda x: (x, x**2))  # each item becomes a tuple (x, x**2)
node = Batcher(node, batch_size=4)  # group 4 consecutive items into one batch
loader = Loader(node)
print(list(loader))
```
This results in batches of tuples, e.g. [(0, 0), (1, 1), (2, 4), (3, 9)], instead of tuples of batches, e.g. ([0, 1, 2, 3], [0, 1, 4, 9]). What am I doing wrong here?
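One workaround I can think of, assuming Batcher only groups consecutive items into a list and does no collation, is to transpose each batch with an extra Mapper. A minimal sketch (transpose_batch is a hypothetical helper of my own, not a torchdata API):

```python
from typing import Any, List, Sequence, Tuple


def transpose_batch(batch: Sequence[Tuple[Any, ...]]) -> Tuple[List[Any], ...]:
    """Turn a batch of tuples into a tuple of per-field lists.

    Hypothetical helper, not part of torchdata.
    """
    # zip(*batch) groups the i-th element of every tuple together,
    # e.g. (0, 1, 2, 3) for the inputs and (0, 1, 4, 9) for the squares.
    return tuple(list(field) for field in zip(*batch))


# The first batch produced by the pipeline above.
batch = [(0, 0), (1, 1), (2, 4), (3, 9)]
print(transpose_batch(batch))  # ([0, 1, 2, 3], [0, 1, 4, 9])
```

In the pipeline above this would go right after the Batcher line, as node = Mapper(node, map_fn=transpose_batch).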