With 4 GPUs and 3 dataloader workers per GPU, I end up with 4 x (1 + 3) x 32 = 512 compile worker processes at ~400 MB each, i.e. roughly 200 GB of RAM:
- 4 GPUs
- 1 main process + 3 dataloader workers per GPU
- 32 compile worker processes each
Why? Can I specify the number of compile workers created? Can I disable the compile workers in the dataloader workers?
Note: I am on torch 2.5.1 and I use torch.compile() with no extra arguments (i.e. the defaults).
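Partially answering my own "can I specify the number of workers" question: from reading the inductor config (not documented behavior, so treat this as an assumption), the worker pool size seems to come from `torch._inductor.config.compile_threads`, which defaults to min(32, CPU count) and can be overridden with an environment variable:

```python
import os

# Assumption: inductor sizes its compile worker pool from
# torch._inductor.config.compile_threads, which this environment
# variable overrides. Setting it to 1 keeps compilation in-process,
# so no worker pool is forked. It must be set before torch is
# imported (or at least before the first torch.compile call).
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
```

That caps the pool everywhere, though, so it trades away parallel compilation in the main process too; it does not selectively disable the pool in the dataloader workers.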
It looks like it is related to the @torch.compile decorator that I put on a function in one of the files my DataLoader imports.
Is my only option to remove this decorator? Shouldn't torch start the compile worker processes only when it actually begins compiling? In my case, this function is never even called in the DataLoader.
So, the problem occurs when my DataLoader imports a file containing a decorated function:
```python
@torch.compile
def do_something(A, B):
    return A + B
```
A simple workaround that avoids creating 32 compile processes in each DataLoader worker is to defer the decoration to call time:
```python
def do_something(A, B):
    @torch.compile
    def _do_something(A, B):
        return A + B
    return _do_something(A, B)
```
When do_something() is called, it compiles as expected. If it is never called, no compile processes are created.
It would be nice not to have to do that.
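One wrinkle with the workaround above is that the inner function is re-decorated on every call. The deferred pattern can be factored into a small helper that wraps the function exactly once, on first call; a sketch (the `lazy_wrap` helper is my own, not a torch API, and it is shown with `wrapper` as a parameter so it works for `torch.compile` or any other expensive decorator):

```python
import functools

def lazy_wrap(wrapper):
    """Decorator factory: apply `wrapper` (e.g. torch.compile) to the
    function on its first call, instead of at import time."""
    def decorate(fn):
        wrapped = None

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            nonlocal wrapped
            if wrapped is None:
                # First call: pay the wrapping cost now, cache the result.
                wrapped = wrapper(fn)
            return wrapped(*args, **kwargs)
        return inner
    return decorate

# Usage: replaces `@torch.compile` at import time, so merely importing
# this module (e.g. in a DataLoader worker) spawns nothing.
# @lazy_wrap(torch.compile)
# def do_something(A, B):
#     return A + B
```

Importing a module that uses `@lazy_wrap(torch.compile)` does not touch torch.compile at all; the compile worker pool only appears in processes that actually call the function.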