How to use nn.DataParallel without splitting input

Currently, my input is generated by a dataloader, and each batch is not easy to split into several smaller batches.

I was wondering if there is any way to directly feed several batches from the dataloader, instead of splitting each batch.

Any comment would be appreciated.

Thank you!

Update:

Just got an idea: simply remove the nn.parallel.scatter call, so the function takes a list of already-split batches directly.

import torch
import torch.nn as nn

def data_parallel(module, inputs, device_ids, output_device=None):
    # `inputs` is a list of batches, one per device, so scatter is not needed.
    if not device_ids:
        # No GPUs available: run the batches sequentially and concatenate,
        # mirroring what gather does below.
        return torch.cat([module(inp) for inp in inputs])

    if output_device is None:
        output_device = device_ids[0]

    # One replica of the module per device; each batch in `inputs`
    # must already live on its corresponding device.
    replicas = nn.parallel.replicate(module, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)
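
For what it's worth, here is a minimal sketch of how the modified function could be called, assuming two GPUs, a toy linear model, and a dataloader that yields plain tensors (the model, dataset, and batch size below are placeholders, not from my actual setup):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device_ids = [0, 1]
model = nn.Linear(16, 4).cuda(device_ids[0])

# Toy dataloader standing in for the real one.
dataset = TensorDataset(torch.randn(64, 16))
loader = DataLoader(dataset, batch_size=8)

it = iter(loader)
while True:
    try:
        # Pull one batch per GPU and move each batch to its own device,
        # since the modified data_parallel no longer scatters for us.
        batches = [next(it)[0].cuda(d) for d in device_ids]
    except StopIteration:
        break
    out = data_parallel(model, batches, device_ids)  # gathered on device_ids[0]
    print(out.shape)  # (len(device_ids) * batch_size, 4)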