Hey all!
First, I want to thank all Torch developers & contributors for a great piece of software and the community for some very informative discussions and resources.
I’m working on a project where I’d like to use Hogwild! / ASGD [1, 2] to replicate / approximate Mikolov’s word2vec [3]. Overall, the code structure can be described as:
- The main process initializes a model and adds it to shared memory (`model.share_memory()`).
- The main process also computes certain statistics of the entire dataset, which are shared with the background processes for putting together instances to train with (see [3] and Mikolov’s papers).
- The main process starts `P` background processes using Python’s `multiprocessing` module.
- Each background process is responsible for training the model with a (roughly equal) subset of the input data, as shown in the torch example [2]. For example, if `P=5`, each bg process trains with 1/5 of the original data; if `P=20`, each process works with 1/20-th of the data, and so on.
- Just like the sample code [2], there is no explicit locking.
I’d like all of this to run on CPU. Below is some sample code showing what is going on.
Sample Code
```python
import math
import multiprocessing as mp

import torch
import torch.nn as nn


# This is what the main process code looks like
def training(n_procs: int, n_docs: int, n_words: int, emb_size: int):
    model = MyDoc2Vec(n_docs, n_words, emb_size)  # internally initializes nn.Parameters
    model.share_memory()
    docs_per_proc = math.ceil(n_docs / n_procs)
    bg_processes = []
    for proc_i in range(n_procs):
        start_doc = proc_i * docs_per_proc
        end_doc = min(start_doc + docs_per_proc, n_docs)
        p = mp.Process(target=bg_training, args=(model, start_doc, end_doc), daemon=True)
        bg_processes.append(p)
        p.start()
    for p in bg_processes:
        p.join()
    return


# This is what the background process code looks like.
# Though not shown here, the dataset to read from is available globally;
# the background processes do not have to read the dataset again before doing any work.
def bg_training(model: nn.Module, start_idx: int, end_idx: int):
    # torch.set_num_threads(1)  # See Q1 for details
    optimizer = torch.optim.SGD(model.parameters(), lr=INITIAL_LR)  # INITIAL_LR is an imported constant
    loss_func = NegativeSampling()
    for doc_idx in range(start_idx, end_idx):
        # Single class holding 3 np.arrays (kept in sync). Can be > BATCH_SIZE.
        # See _NCEBatch in https://github.com/inejc/paragraph-vectors/blob/master/paragraphvec/data.py#L325C7-L325C16 as an example
        doc_batches = get_batches_from_doc(doc_idx)
        losses = []
        offset = 0
        while offset < len(doc_batches):
            batch_subset = doc_batches.get(offset, offset + BATCH_SIZE)  # BATCH_SIZE is an imported constant

            ## If I turn this section OFF, all processes show up as active in htop (see Q2)
            # Converts elements to torch.LongTensor. No CUDA involved
            batch_tensors = batch_subset.to_tensor()
            # Scores from the model & loss computation
            scores = model.forward(batch_tensors)  # <- See Q1: deadlock here. No locking in the model
            loss = loss_func.forward(scores)
            loss_value = loss.item()  # to help with reporting in the end
            losses.append(loss_value)
            # backprop section
            model.zero_grad()
            loss.backward()
            optimizer.step()

            offset += BATCH_SIZE
        # Do reporting here (outside the loop)
    return
```
Based on my observations so far, I have a few questions that I could use some input / help with:
Q1: Deadlocks in Background Processes
In my first attempt(s), I was hit by deadlocks in the background processes, as discussed in [4]. This was eventually addressed by switching to `intel-numpy`. Setting `torch.set_num_threads(1)` (suggested in the link) or the equivalent `OMP_NUM_THREADS` “worked”, in the sense that the deadlocks were gone, but the efficiency was abysmal (6h+ for a single epoch on a small-ish dataset).
Switching to PyTorch 2 did not help in my case.
Are these the only alternatives to this problem (deadlocks)? Is there some better approach?
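For concreteness, here is a minimal sketch of the two workarounds I’m aware of: the thread-limiting one I actually used, and the `spawn` start method (related to the fork+OpenMP issue in [4]), which I have not evaluated for speed or shared-memory behavior. This assumes the `bg_training` from the sample code above.

```python
import multiprocessing as mp
import torch

def bg_training_limited(model, start_idx, end_idx):
    # Limit intra-op threads inside the child, before any OpenMP work runs;
    # in effect equivalent to exporting OMP_NUM_THREADS=1 for this process.
    torch.set_num_threads(1)
    bg_training(model, start_idx, end_idx)

# Alternative I have NOT tried end-to-end: avoid fork()-ing a process that
# already holds live OpenMP threads by using the "spawn" start method.
# (I believe torch.multiprocessing would then be needed so the shared-memory
# tensors survive being sent across the spawn boundary.)
# ctx = mp.get_context("spawn")
# p = ctx.Process(target=bg_training, args=(model, start_doc, end_doc), daemon=True)
```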
Q2: Performance of Background Processes
As mentioned above, using `set_num_threads(1)` (in the bg process) or setting `OMP_NUM_THREADS=1` had terrible performance.
I should mention here that the machine I’m using has sufficient resources both in RAM (say, 1 TB) and CPU (say, 128 CPUs). No other resource-hungry process runs on it. I was also explicitly setting the env var `OMP_NUM_THREADS`, to prevent the background processes from assuming they could each spawn a number of threads equal to the machine’s CPUs (which resulted in resource contention and a lot of kernel-space computation).
As a broader pattern, I noticed that, everything else being the same, I was getting better (lower) wall-clock time when using fewer processes with more threads, rather than the other way around. For example, 3 processes with 48 threads were considerably faster than 24 processes with 6 threads. This is in line with the above observation that `n_procs=96` and `OMP_NUM_THREADS=1` had terrible performance, even though, theoretically, we should be seeing (close to) “perfect” parallelism (I understand this is idealizing things).
Equally surprising, but in line with the above, when using more processes and fewer threads, `htop` would report underutilization of the available CPUs. In particular:
- For `n_procs=3` and `OMP_NUM_THREADS=48`, `htop` reports ~96 CPUs as active (2*48).
- For `n_procs=24` and `OMP_NUM_THREADS=6`, `htop` would report ~12 CPUs as active (2*6).
- For `n_procs=2` and `OMP_NUM_THREADS=64`, `htop` would report ~128 CPUs as active (2*64).
It’s almost as if only 2 background processes were “really” doing any work. The last combination suggests that there is enough work to go around. At the same time, if I turn off the `torch`-related parts of the code (leaving basically just reading through the doc, sampling, etc.), more processes appear active in `htop`.
Is this behavior expected, known, or does it otherwise sound familiar to anyone? Separately, is there any guidance on the number of threads a (multiprocess) application should have? I understand there’s no silver bullet here, but a rule of thumb (depending on the operations involved) would be super helpful.
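To make the thread/CPU bookkeeping explicit, this is roughly the per-worker setup I’ve been experimenting with (a sketch; `configure_worker` is a hypothetical helper, not part of the code above, and `os.sched_setaffinity` is Linux-only):

```python
import os
import torch

def configure_worker(proc_i: int, n_procs: int, n_cpus: int):
    # Split the machine's CPUs evenly across the background processes.
    threads_per_proc = max(1, n_cpus // n_procs)
    first_cpu = proc_i * threads_per_proc
    # Pin this worker to its own slice of cores (0 = the calling process),
    # so workers don't all pile onto the same cores.
    os.sched_setaffinity(0, range(first_cpu, first_cpu + threads_per_proc))
    # Cap PyTorch's intra-op thread pool to the same number.
    torch.set_num_threads(threads_per_proc)
```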
Q3: Doc2Vec Hogwild! with Python
As mentioned in the original paper [5], Hogwild! is well suited when the optimization problem is sparse. For W2V & Doc2Vec, this translates to updating (very) few words & docs at a time.
If I’m piecing things together properly & using the verbiage of an open-source implementation of Doc2Vec [6], we want a small `batch_size`. Each example in a batch contains 1 positive and several negatives, and `batch_size` controls how many such examples we process in a single forward + optimization pass.
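To make the batch structure concrete, here is roughly what one forward pass sees and what the loss computes. This is my paraphrase of [6], not a verbatim copy; the exact shapes and normalization may differ from the actual implementation.

```python
import torch
import torch.nn.functional as F

def negative_sampling_loss(scores: torch.Tensor) -> torch.Tensor:
    # scores: (batch_size, 1 + num_noise_words)
    # column 0 holds the score of the true (positive) word,
    # columns 1: hold the scores of the sampled negative words.
    positive = F.logsigmoid(scores[:, 0])                # push positive scores up
    negatives = F.logsigmoid(-scores[:, 1:]).sum(dim=1)  # push negative scores down
    return -(positive + negatives).mean()
```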
Very small (e.g., 5, 20, 100) or even small-to-medium batch sizes (e.g., 1000) resulted in extremely poor performance (separately from the observations above): a very long time was spent in the forward passes of the model and in computing the loss, and significant time was also spent in backprop. Though I didn’t profile this, I suspect the time was really just spent in the Python interpreter, even though the loop body only involves torch operations. https://github.com/inejc/paragraph-vectors/blob/master/paragraphvec/loss.py and https://github.com/inejc/paragraph-vectors/blob/master/paragraphvec/models.py#L31 are good reference points for what’s involved.
Increasing the batch size to 10000 dramatically dropped the wall-clock time; the processing time per doc improved by ~2-3 orders of magnitude, and efficiency kept improving beyond that, though the difference was less noticeable. Granted, I didn’t do a complete scan of batch sizes to know where performance really started taking off. Unfortunately, I think that the more we increase the batch size, the more likely we are to invalidate the assumption of a “sparse optimization problem” and the sparse updates that Hogwild! relies on. Basically, larger batch sizes increase the probability that different processes (each with multiple parallel threads) will operate on the same word at the same time, which we want to avoid.
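Since I haven’t profiled this properly yet, here is a minimal sketch of how I plan to check where the time actually goes for small batch sizes. It assumes the same `doc_batches` / `to_tensor()` API as the sample code above; `torch.profiler` should be available in 1.13.

```python
import torch
from torch.profiler import ProfilerActivity, profile

def profile_one_doc(model, optimizer, loss_func, doc_batches, batch_size):
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        offset = 0
        while offset < len(doc_batches):
            batch_tensors = doc_batches.get(offset, offset + batch_size).to_tensor()
            loss = loss_func.forward(model.forward(batch_tensors))
            model.zero_grad()
            loss.backward()
            optimizer.step()
            offset += batch_size
    # Time not attributed to aten::* ops is a rough proxy for pure interpreter overhead.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```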
With the above in mind, a few questions:
- Could the slowdown be attributed to something other than the interpreter? E.g., is there any additional initialization happening internally in PyTorch (e.g., thread pools, other) that might explain the perf difference? I understand that PyTorch will build a computational graph internally to track things, but the overhead (2-3 orders of magnitude) seems too much (I think) for such things. Might there be something else going on I should look at?
- Is there an alternative strategy (API, implementation) or design that would allow me to use smaller batch sizes without a performance penalty? E.g., would I need to use the C++ API instead? Or is there some obvious torch API I should be using here?
- Is my understanding correct that a larger batch size (as described above) runs counter to the idea of a “sparse optimization problem”, in terms of how this gets implemented / executed in PyTorch?
Configuration
- Ubuntu Linux 20.04
- Python 3.9
- Torch 1.13.1+cu117
- numpy 1.21.6 // intel-numpy 1.21.4
- Python `multiprocessing`. The drop-in replacement `torch.multiprocessing` had the same behavior.
Sorry for the long post. Let me know if clarifications are needed. Thanks in advance!
References
[1] Multiprocessing best practices — PyTorch 2.0 documentation
[2] https://github.com/pytorch/examples/blob/main/mnist_hogwild/train.py
[3] https://github.com/tmikolov/word2vec/blob/master/word2vec.c
[4] Deadlock with multiprocessing (using fork) and OpenMP / PyTorch should warn after OMP and fork that multithreading may be broken · Issue #17199 · pytorch/pytorch · GitHub
[5] https://people.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf
[6] GitHub - inejc/paragraph-vectors: 📄 A PyTorch implementation of Paragraph Vectors (doc2vec).