Too many compile workers and huge RAM consumption

With 4 GPUs and 3 dataloader workers per GPU, I end up with 4 x (1 + 3) x 32 = 512 compile worker processes, each using ~400 MB, for a total of ~200 GB of RAM:

  • 4 GPUs
  • 1 main process + 3 dataloader workers per GPU
  • 32 compile worker processes each

Why does this happen? Can I specify the number of compile workers created? Can I disable the compile workers in the dataloader workers?

Note: I'm on torch 2.5.1 and I use torch.compile() with no extra arguments (i.e., the defaults).
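For what it's worth, the 32 processes per Python process seem to come from inductor's parallel compile pool, which appears to default to min(32, cpu_count) workers. A minimal sketch of capping it, assuming the TORCHINDUCTOR_COMPILE_THREADS environment variable (which feeds torch._inductor.config.compile_threads) is honored in 2.5.1:

import os

# Assumed mitigation: cap inductor's compile pool to a single worker.
# This must run before torch (or anything that imports torch) is loaded,
# since the config value is read when the inductor module is imported.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import torch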

Are you calling torch.compile inside your Dataset or DataLoader? If not, I would not expect to see any additional workers related to torch.compile.

Thanks @ptrblck!

It looks like it's caused by the @torch.compile decorator that I put on a function defined in one of the files my DataLoader imports.

Is my only option to remove this decorator? Shouldn't the compile worker processes only start when compilation actually begins? In my case, this function is never called in the DataLoader.

So, the problem occurs when my DataLoader imports a file with a decorated function:

import torch

@torch.compile  # decorating at import time is what spawns the compile worker pool
def do_something(A, B):
    return A + B

A simple workaround to avoid creating 32 compile processes in each DataLoader worker is:

import torch

def do_something(A, B):
    # Defer the torch.compile call until the function is actually used,
    # so merely importing this module does not spawn compile workers.
    @torch.compile
    def _do_something(A, B):
        return A + B
    return _do_something(A, B)

When do_something() is called, it compiles as expected. If it's never called, no compile processes are created.
It would be nice not to need this workaround, though.
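A variant that avoids re-wrapping the inner function on every call is to cache the compiled wrapper at module level. A sketch (_compiled is a hypothetical cache variable; this still relies on torch.compile() only being invoked, and the pool only being spawned, on first use):

import torch

_compiled = None  # hypothetical module-level cache for the compiled wrapper

def do_something(A, B):
    # Compile lazily, and only once: importing this module spawns no
    # compile workers, and repeated calls reuse the same wrapper.
    global _compiled
    if _compiled is None:
        def _impl(A, B):
            return A + B
        _compiled = torch.compile(_impl)
    return _compiled(A, B)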