How should I use torch.compile properly?

I have multiple questions about how to use torch.compile properly. I found the tutorial/documentation lackluster. I will try to list all of them, including those I found answered on this forum but missing from the tutorial, for future readers.

  • How to use torch.compile with a non-trivial nn.Module? This post on Stack Overflow perfectly sums up my question. Namely, should torch.compile be called manually on every sub-module I wish to compile, or does it handle that automatically? More generally, how does torch.compile behave when used on an nn.Module? Does it compile only the forward method? In my case, similar to the OP in that Stack Overflow post, I observe zero performance difference when using torch.compile, even though my model has 27M parameters and I am using one of the GPUs recommended in the tutorial (V100). I also note that torch.compile returns immediately, as if it were doing nothing.

  • Should I wrap the model in DDP before or after torch.compile? In the DistributedDataParallel tutorial (can’t post another link due to the new-user limitation on this forum), torch.compile is called on the DDP-wrapped model. But in this post, it is recommended to call torch.compile before wrapping. So what is the right answer?

  • Should I move the model to the device before or after torch.compile? In that same post, it is recommended to move the model to CUDA before calling torch.compile. But this is not mentioned in the tutorial, even though it seems like a pretty essential consideration.


Hi

  • I answered the question on Stack Overflow. One aspect that’s important to note is that V100 improvements are alright for torch.compile, but the real speedups will come from an A100 or A10G. And yes, indeed nothing will happen when you only call torch.compile: the compilation happens at the time of the first inference. The name is sorta bad, it should be called torch.jit but that was already taken XD
  • Regarding your question on distributed, today the distributed support is not very fleshed out. There’s a tradeoff: if you compile the DDP-wrapped module, torch.compile should be able to trace the communication and do more optimizations there, but it doesn’t today, so stuff is likely to break. It’s safer to compile the inner module (see the sketch after this list), but this will likely evolve in our next releases.
  • It’s not a huge deal either way. Inductor would just prefer it if you don’t have too many device copies, but that’s not gonna break it or cripple perf.
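To make that ordering concrete, here’s a rough sketch of what I’d do today, assuming a script launched with torchrun (the model, sizes and variable names are just placeholders):

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
model = model.cuda(local_rank)               # move to the device before compiling
model = torch.compile(model)                 # compile the inner module first...
model = DDP(model, device_ids=[local_rank])  # ...then wrap the compiled module in DDP

# torch.compile returns immediately; the actual compilation happens lazily,
# on the first forward pass.
out = model(torch.randn(32, 128, device=f"cuda:{local_rank}"))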

Thank you so much for the quick reply!

I have some follow-up questions. I have a trainer class that takes an initialized nn.Module as an argument. The model defines train_step and val_step methods that are called by the trainer during training, à la Lightning. I want to be able to pass a compile boolean argument to my trainer to call torch.compile on the model. Something like this (code heavily simplified for this question):

class MyTrainer:
    def __init__(self, model, compile, ...):
        if compile:
            model = torch.compile(model)
        self.model = model
        ...  # initialize dataloaders

    def run(self):
        for epoch in range(self.epochs):
            self.model.train()
            for batch in self.train_dataloader:
                train_loss = self.model.train_step(batch)
            ...  # accumulate and log train loss
            self.model.eval()
            for batch in self.val_dataloader:
                val_loss = self.model.val_step(batch)
            ...  # accumulate and log val loss


class MyModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        ...  # initialize model, optimizer and criterion

    def train_step(self, batch):
        self.optimizer.zero_grad()
        loss = self._step(batch)
        loss.backward()
        self.optimizer.step()
        return loss

    def val_step(self, batch):
        loss = self._step(batch)
        return loss

    def _step(self, batch):
        inputs, labels = batch[:, 0], batch[:, 1]
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        return loss.mean()

This currently does nothing. I believe it’s because torch.compile(model) returns a wrapper module whose __call__ is the optimized entry point, and by assigning self.model = model I lose the reference to the original model. When the trainer calls train_step and val_step, these are forwarded to the original model’s implementation, so the self(inputs) call inside _step goes through the original, un-optimized __call__ rather than the compiled one. Am I correct?
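A quick way to check this, if I understand the wrapper correctly (OptimizedModule and _orig_mod are internal PyTorch 2.x names that may change):

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
cm = torch.compile(m)

print(type(cm).__name__)  # OptimizedModule: a wrapper around the original module
print(cm._orig_mod is m)  # True: the original module is kept as _orig_mod

# cm(x) goes through the wrapper's compiled __call__, but any other attribute
# lookup (e.g. a custom train_step) is forwarded to the original module, so the
# self(...) call inside that method uses the original, un-compiled __call__.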

I found out that using torch.compile this way instead does do something; namely, I compile the forward method manually:

class MyTrainer:
    def __init__(self, model, compile, ...):
        if compile:
            model.forward = torch.compile(model.forward)
        self.model = model
        ...  # initialize dataloaders

Alternatively, compiling the train_step and val_step methods this way also does something, though the result is different according to the torch._dynamo logs:

class MyTrainer:
    def __init__(self, model, compile, ...):
        if compile:
            model.train_step = torch.compile(model.train_step)
            model.val_step = torch.compile(model.val_step)
        self.model = model
        ...  # initialize dataloaders

Are these valid approaches for my use case? Should one be preferred over the other? Note that in the second solution, the compiled function contains the optimizer.zero_grad(), loss.backward() and optimizer.step() calls. Is it wrong to compile these?

On another note, from what I understand from the logs produced by torch._dynamo, compilation happens not only on the first batch, but on every batch during the first epoch. Is this normal? Starting from the second epoch, no more logs are produced.
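For reference, a minimal way to surface these recompilations explicitly, assuming a recent PyTorch 2.x where torch._logging.set_logs is available (older builds expose similar switches under torch._dynamo.config):

import torch
import torch.nn as nn

# Log the reason every time Dynamo recompiles or hits a graph break.
torch._logging.set_logs(recompiles=True, graph_breaks=True)

model = torch.compile(nn.Linear(16, 16))

# Varying input shapes (e.g. a smaller last batch) are one classic trigger
# for extra compilations during the first epoch.
for batch_size in (32, 32, 7):
    model(torch.randn(batch_size, 16))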

Yeah, integrations with training-loop providers can be funky. This is a different take on the API: https://github.com/pytorch/pytorch/pull/97565 - lmk what you think, because I tried to include in core exactly what you’ve done.

Mmh I see, so if I understand correctly from here,

  • You define an in-place compile method for the model which assigns the optimized module to an attribute of the model
  • When calling the model, you check if the optimized module attribute has been assigned and use it if it was
  • When saving the model, you make sure the optimized model state is not dumped
  • If persistent_compilation is True, the model is compiled upon loading

Is this correct? I can foresee this being easily applied to my framework by extending my model base class. Will report back.

So I tried implementing the idea from the PR you linked above as follows:

class MyBaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._optimized_module = None

    def compile(self, *args, **kwargs):
        self._optimized_module = torch.compile(self, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self._optimized_module is not None:
            return self._optimized_module(*args, **kwargs)
        else:
            return super().__call__(*args, **kwargs)

    def __getstate__(self):
        state = self.__dict__.copy()
        # Remove _optimized_module from the state
        state.pop('_optimized_module', None)
        return state

However this produces a RecursionError: maximum recursion depth exceeded error when calling model.train() and probably other methods that iterate over child modules. I guess this makes sense: since the compiled wrapper is itself an nn.Module, assigning it to an attribute registers it as a child module of the model, while the model is in turn a child of the wrapper, so model.train() eventually calls self._optimized_module.train(), which calls back into the un-optimized module’s .train(), and so on, producing the infinite recursion.
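A minimal reproduction of that child-module registration (the Wrapped class below is just a placeholder for illustration, not part of my actual code):

import torch
import torch.nn as nn

class Wrapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(2, 2)

    def forward(self, x):
        return self.inner(x)

m = Wrapped()
m._optimized_module = torch.compile(m)  # the wrapper is an nn.Module itself

print([name for name, _ in m.named_children()])
# ['inner', '_optimized_module'] -> the wrapper is now a child of m, and m is
# in turn a child of the wrapper, so m.train() recurses until the stack blows up.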

The following, which is a subset of the PR you linked, prevents the recursion and seems to work:

class MyBaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._compiled_call_impl = None

    def compile(self, *args, **kwargs):
        self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self._compiled_call_impl is not None:
            return self._compiled_call_impl(*args, **kwargs)
        else:
            return self._call_impl(*args, **kwargs)

    def __getstate__(self):
        state = self.__dict__.copy()
        # Remove _compiled_call_impl from the state
        state.pop('_compiled_call_impl', None)
        return state
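With this in place, the trainer only needs the in-place call; a trimmed sketch, assuming MyModel now inherits from MyBaseModel:

class MyModel(MyBaseModel):
    ...  # same train_step / val_step / _step as before

# in MyTrainer.__init__
if compile:
    model.compile()  # compiles self._call_impl in place, no wrapper to juggle
self.model = model   # train_step / val_step are untouched, and the self(inputs)
                     # call inside _step goes through the compiled path when set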

Should I be wary of using internal variables like _call_impl? Also, is the rest of the code in the PR necessary for what I want to achieve?