Hello everyone,
I’m training a sequence-based model on a single machine with 1 GPU and 16 CPU cores. My loss function is computationally expensive and performs best on the CPU. The reason is that computing this particular loss is an inherently sequential process: it can’t be parallelized or GPU-optimized for a single data point, but it can be parallelized across a batch of samples. The rest of the model works well on a GPU.
I’d like to know how to parallelize my loss computation over a batch of samples across the CPU cores while keeping the rest of the model on the GPU. I know this will involve transferring the model’s predictions from GPU -> CPU, but I believe the speed-up from computing the loss on the CPUs will outweigh the cost of that transfer.
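To make the setup concrete, here is roughly the training step I have in mind. This is only a sketch: model, optimizer, and dataloader are placeholders for my actual setup, single_item_loss is my loss function, and the per-sample loop is the part I want to spread across the 16 CPU cores.

import torch

for inputs in dataloader:
    optimizer.zero_grad()
    predictions = model(inputs.to("cuda"))  # forward pass stays on the GPU
    predictions_cpu = predictions.cpu()     # move predictions to the CPU for the loss
    # this per-sample loop is what I would like to run across the 16 CPU cores:
    losses = [single_item_loss(pred) for pred in predictions_cpu]
    loss = torch.mean(torch.stack(losses))
    loss.backward()                         # gradients should flow back to the GPU parameters
    optimizer.step()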
I’ve tried torch.multiprocessing as follows, where single_item_loss is the loss function described above: it takes a single model prediction from the batch of predictions and returns a single torch float tensor representing the loss:
with multiprocessing.Pool(multiprocessing.cpu_count()) as p:
    results = p.map(single_item_loss, predictions)
but this yields the error:
RuntimeError: Cowardly refusing to serialize non-leaf tensor which requires_grad, since autograd does not support crossing process boundaries. If you just want to transfer the data, call detach() on the tensor before serializing (e.g., putting it on the queue).
I’ve also tried using the joblib library:

results = Parallel(n_jobs=multiprocessing.cpu_count(), backend="loky")(
    delayed(single_item_loss)(pred) for pred in predictions
)
but this must sever each of the individual losses from the computation graph, because calling torch.mean(torch.stack(results)).backward() (and later optimizer.step()) has no effect and the model does not train.
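For completeness, the surrounding training step in that joblib attempt looked roughly like this (same placeholder names as in the sketch above, again only illustrative):

predictions = model(inputs.to("cuda")).cpu()
results = Parallel(n_jobs=multiprocessing.cpu_count(), backend="loky")(
    delayed(single_item_loss)(pred) for pred in predictions
)
loss = torch.mean(torch.stack(results))  # the returned losses appear to be detached from the graph
loss.backward()                          # has no visible effect
optimizer.step()                         # no gradients reach the model, so it does not train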
I’ve also tried calling loss.backward() before returning loss from my single_item_loss function, but this also does not work.
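In case it matters, that variant looked roughly like this (an illustrative stand-in, not my exact code):

def single_item_loss(pred):
    # stand-in for my actual sequential, CPU-bound loss computation
    loss = (pred ** 2).sum()
    loss.backward()  # the variant I tried: backward inside the worker process
    return loss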
I’m eager to hear any feedback and to work through this problem. Evidently this is not a very common situation, as GPUs are almost always preferred, but in this case I believe the custom loss function warrants this special treatment. Thank you very much for your help!