Currently my input comes from a DataLoader, and each batch is not easy to split into several smaller batches.
I was wondering if there is any way to feed several DataLoader batches directly (one per GPU) instead of splitting each batch.
Any comment would be appreciated.
Just got an idea:
simply remove the `nn.parallel.scatter` step from `data_parallel`, so each batch goes to its own replica unchanged:
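```python
import torch
import torch.nn as nn


def data_parallel(module, inputs, device_ids, output_device=None):
    # `inputs` is a list of batches; inputs[i] must already be on device_ids[i].
    if not device_ids:
        # No devices given: run the batches one after another and concatenate.
        return torch.cat([module(x) for x in inputs])
    if output_device is None:
        output_device = device_ids[0]
    # Copy the module to every device (no scatter: batches are not split).
    replicas = nn.parallel.replicate(module, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)
```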
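The batches still have to be grouped and placed on the right devices before calling the function. A minimal sketch, assuming each DataLoader batch is a single tensor (not an `(input, target)` tuple); `grouped_batches` is a hypothetical helper, not part of PyTorch:

```python
import itertools


def grouped_batches(loader, device_ids):
    """Yield lists of batches from `loader`, one batch per device (hypothetical helper).

    Batch i of each group is moved to device_ids[i] with `.to()`. The last group
    may hold fewer batches than there are devices; data_parallel above copes with
    that via `replicas[:len(inputs)]`.
    """
    it = iter(loader)
    while True:
        group = list(itertools.islice(it, len(device_ids)))
        if not group:
            return
        yield [batch.to(dev) for batch, dev in zip(group, device_ids)]
```

With two GPUs, something like `for inputs in grouped_batches(loader, [0, 1]): out = data_parallel(model, inputs, [0, 1])` should then process two DataLoader batches per step, with the outputs concatenated on GPU 0. Any targets would need to be grouped and concatenated in the same order.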