I think one potential issue is that the numerical CPU backends PyTorch relies on, such as MKL/oneDNN, may already use more than one thread per process by default. Do you see scaling behavior closer to what you expect when setting, e.g., OMP_NUM_THREADS=1 and MKL_NUM_THREADS=1?
See also:
https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
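For reference, a minimal sketch of how one might pin everything to a single thread. This assumes the environment variables are set before torch is imported, since the OpenMP/MKL runtimes typically read them once at initialization:

```python
import os

# Must be set before importing torch; the threading runtimes
# typically read these only once when the libraries load.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import torch

# Also cap PyTorch's own intra-op thread pool.
torch.set_num_threads(1)

print(torch.get_num_threads())  # expect 1
```

Setting the variables in the shell before launching the interpreter (e.g., `OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 python script.py`) is the more reliable route, since an earlier import could already have initialized the threading libraries.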