Parallelize along the samples of a batch

I need to parallelize a model along the samples of a batch on CPU. The reason is that my model is a recursive network, so I don’t have tensors that I can stack into a batch, but tree structures that I have to unfold during training.
Basically, the code I want to achieve looks like this:

for batch in batches:
    loss = 0  # needs to be synchronized among processes
    for sample in batch:  # sample is a recursive structure, not a tensor
        # do this in parallel on batch_size-many CPU cores
        loss += model(sample)
    # sum/average all losses from the subprocesses
    loss /= num_processes
    loss.backward()
    optimizer.step()

Basically, it should work like a multiprocessing.Pool. However, I cannot use that, since gradients cannot be shared across processes. I know that DistributedDataParallel exists, but I don’t see how I can use it to achieve my goal.

Thanks!

I’m not sure I understand, but couldn’t you just parallelize at the batch level? (Every process would get a subset of the batches.)

How can I parallelize at the batch level? Maybe I could set the batch size to 1 and do a gradient update every 16 batches or so…

If the batch is not a tensor and you are forced to iterate over it anyway, then yes, I guess that setting the batch size to 1 with DDP could do the trick.

Could you please show some code here? I don’t see how to do it.

You would probably need to write your own custom data loader to handle data structures other than tensors, but the general scheme could be something like this:

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

model = DDP(model, ...)
sampler = DistributedSampler(dataset)
dataloader = CustomDataLoader(dataset, shuffle=False, sampler=sampler, batch_size=1)

for sample in dataloader:
    loss = model(sample)
    optimizer.zero_grad()
    loss.backward()  # gradient synchronization across processes happens here automatically
    optimizer.step()
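
For concreteness, here is a minimal sketch of how that scheme could look end to end, assuming the "CustomDataLoader" above is just a regular DataLoader whose collate_fn hands back the raw, non-tensor sample. MyRecursiveModel, TreeDataset, run_worker and the hyperparameters are placeholder names for your own code, not anything from this thread:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def run_worker(rank, world_size, dataset):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)  # gloo = CPU backend

    model = DDP(MyRecursiveModel())  # placeholder for your recursive model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    # collate_fn returns the raw sample instead of trying to stack tensors
    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler,
                            collate_fn=lambda samples: samples[0])

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reshuffle consistently across processes
        for sample in dataloader:  # sample is the raw tree structure
            loss = model(sample)
            optimizer.zero_grad()
            loss.backward()  # gradients are all-reduced across processes here
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4  # number of CPU processes
    dataset = TreeDataset()  # placeholder for your dataset of tree structures
    mp.spawn(run_worker, args=(world_size, dataset), nprocs=world_size)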

Thank you. I implemented this approach. But when I set batch_size=1, it performs a gradient update after every sample, right? That is not what I want; I want larger batches. Is there a way to achieve this?
When I set batch_size=16, every process gets all samples of this batch. Maybe I misunderstood DDP, but shouldn’t it split the batch into sub-batches < 16 and distribute these to the workers? The batches are Python lists in my case.

When using DDP you are “simulating” a total batch size equal to batch_size * world_size. DDP does not split the batch loaded in a single process; rather, each process loads its own batch through the DistributedSampler. If you set the batch size to 1 and spawn 16 workers (while using the DistributedSampler), you are synchronizing gradients from 16 different samples at the same time.
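
If you additionally want each optimizer step to cover more than one sample per process, one option (not from this thread, just a common pattern) is gradient accumulation: run several forward/backward passes before calling optimizer.step(), and use DDP's no_sync() context manager to skip the gradient all-reduce on the intermediate backward passes. A rough sketch on top of the batch_size=1 loop above, where accum_steps is an assumed value:

accum_steps = 16  # samples accumulated per process before each optimizer step (assumption)

optimizer.zero_grad()
for i, sample in enumerate(dataloader):
    loss = model(sample) / accum_steps  # average the loss over the accumulated samples
    if (i + 1) % accum_steps != 0:
        with model.no_sync():  # skip DDP's all-reduce on intermediate backward passes
            loss.backward()
    else:
        loss.backward()  # gradients are synchronized across processes here
        optimizer.step()
        optimizer.zero_grad()

Each step then effectively uses accum_steps * world_size samples.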

For more details about DDP, refer to the Distributed Data Parallel — PyTorch master documentation.