How to collect a whole batch at once?

You could use a BatchSampler and pass all indices of a batch directly to Dataset.__getitem__, which would allow you to load multiple samples in a single call. This code shows an example.
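Here is a minimal sketch of this approach. It assumes a map-style dataset whose data is small enough to index in memory. `BatchIndexDataset` and its toy tensors are hypothetical and only for illustration. The important parts are wrapping a sampler in `BatchSampler`, passing it via `sampler=`, and setting `batch_size=None`. With `batch_size=None`, the DataLoader disables automatic batching, so each list of indices is handed to `__getitem__` unchanged.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchIndexDataset(Dataset):
    """Hypothetical dataset whose __getitem__ accepts a list of indices."""

    def __init__(self, num_samples: int = 10, num_features: int = 3):
        # Toy in-memory data; a real dataset might read from disk or a database
        self.data = torch.arange(
            num_samples * num_features, dtype=torch.float32
        ).view(num_samples, num_features)
        self.targets = torch.arange(num_samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        # `indices` is a list of ints produced by the BatchSampler,
        # so one call loads the whole batch via advanced indexing
        idx = torch.as_tensor(indices)
        return self.data[idx], self.targets[idx]


dataset = BatchIndexDataset()

# BatchSampler groups the indices from SequentialSampler into lists of 4
sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# batch_size=None disables automatic batching: each list of indices
# from the sampler is passed as-is to dataset.__getitem__
loader = DataLoader(dataset, sampler=sampler, batch_size=None)

for x, y in loader:
    print(x.shape, y)
```

This pattern pays off when loading many samples in one go is cheaper than loading them one by one. Examples include reading a contiguous slice of a memory-mapped array or issuing a single database query per batch.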