Currently, my input is generated by a DataLoader, and each batch is not easy to split into several smaller batches.
I was wondering if there is any way to feed several batches from the DataLoader directly to data parallel training, one batch per GPU, instead of splitting each batch.
Any comment would be appreciated.
Thank you!
Update:
I just got an idea:
skip the nn.parallel.scatter call in data_parallel and pass the list of batches directly as inputs.
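import torch
import torch.nn as nn

def data_parallel(module, inputs, device_ids, output_device=None):
    # inputs: a list of batches; inputs[i] must already be on device_ids[i],
    # because scatter (which used to move the data) is no longer called
    if not device_ids:
        # no GPUs: run each batch in turn and concatenate along dim 0, as gather does
        return torch.cat([module(x) for x in inputs], dim=0)
    if output_device is None:
        output_device = device_ids[0]
    replicas = nn.parallel.replicate(module, device_ids)
    # one replica per batch (requires len(inputs) <= len(device_ids))
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    # collect the per-device outputs on output_device, concatenated along dim 0
    return nn.parallel.gather(outputs, output_device)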
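Since scatter is skipped, nothing moves each batch to its GPU anymore, so the caller has to group consecutive DataLoader batches and place them on the matching devices. Below is a minimal sketch of that step. The helper name grouped_batches is hypothetical (not a PyTorch API), and it assumes every batch is a single tensor; batches of (input, target) tuples would need each element moved separately.

```python
import itertools

import torch
from torch.utils.data import DataLoader


def grouped_batches(loader, devices):
    """Hypothetical helper: yield lists of len(devices) consecutive batches.

    Batch i of each group is moved to devices[i], so the list can be passed
    as `inputs` to a scatter-free data_parallel. Assumes tensor batches.
    """
    it = iter(loader)
    while True:
        group = list(itertools.islice(it, len(devices)))
        if not group:
            return
        # the last group may hold fewer batches than there are devices
        yield [batch.to(dev) for batch, dev in zip(group, devices)]


if __name__ == "__main__":
    # a tensor works as a map-style dataset: 10 rows, batches of 3 -> sizes 3, 3, 3, 1
    loader = DataLoader(torch.arange(10.0).view(10, 1), batch_size=3)
    for group in grouped_batches(loader, ["cpu", "cpu"]):
        print([tuple(b.shape) for b in group])
```

On a multi-GPU machine the loop would look roughly like `for inputs in grouped_batches(loader, [f"cuda:{i}" for i in device_ids]): out = data_parallel(model, inputs, device_ids)`. A short final group is why trimming the replicas to `len(inputs)` matters.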