How to use torch.compile with load_state_dict?

I’m new to PyTorch, so I’m trying to figure out how to use torch.compile in code I haven’t developed.

I have this for example:

    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weights[feature], map_location=device))

So I tried:

    net = torch.compile(UNet(n_channels=1, n_classes=1))
    net.load_state_dict(torch.load(weights[feature], map_location=device))

But execution fails with:

    RuntimeError: Error(s) in loading state_dict for OptimizedModule

Inside my class UNet(nn.Module) I have instances of several other classes that also subclass nn.Module, so I thought I could use torch.compile there as well.

I’ve seen the new tutorials on the PyTorch website, but nothing specific to my case yet.

Could you try to load the state_dict first and torch.compile the model afterwards, or do you have a specific requirement to compile the model beforehand?
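For context on why the original order fails: `torch.compile` returns a wrapper module (`OptimizedModule`) that keeps the original model in its `_orig_mod` attribute, so every key in the wrapper's `state_dict` gains an `_orig_mod.` prefix. A checkpoint saved from a plain `UNet` has no such prefix, hence the error. The sketch below uses a small hypothetical `TinyNet` in place of `UNet` to show the load-then-compile order and what the keys look like. It also shows that compiling a submodule (the idea of compiling the classes inside `UNet`) nests the same prefix into the parent's keys. `torch.compile` only wraps the model here; actual compilation is deferred to the first forward call, which this sketch never makes.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for UNet, kept tiny for illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# Simulate a checkpoint saved from an uncompiled model (like the UNet weights file).
buffer = io.BytesIO()
torch.save(TinyNet().state_dict(), buffer)
buffer.seek(0)

# 1) Load the weights into the plain model first ...
net = TinyNet()
net.load_state_dict(torch.load(buffer, map_location="cpu"))

# 2) ... then wrap it. Compilation itself happens lazily on the first forward call.
compiled = torch.compile(net)

print(list(net.state_dict()))       # ['conv.weight', 'conv.bias']
print(list(compiled.state_dict()))  # ['_orig_mod.conv.weight', '_orig_mod.conv.bias']

# Compiling a submodule nests the prefix inside the parent's keys as well.
nested = nn.Sequential(torch.compile(nn.Linear(2, 2)))
print(list(nested.state_dict()))    # ['0._orig_mod.weight', '0._orig_mod.bias']
```

If compiling first is unavoidable, `compiled._orig_mod.load_state_dict(...)` loads into the wrapped model directly, but `_orig_mod` is a private attribute, so loading before compiling is the cleaner option.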


I’m not sure how I would do that. Here’s the main part of the code that uses the instance net:

    import numpy as np
    from rcia_tools.helpers.for_torch import UNet, predict_bscans, torch

    # weights, feature, device, base, bscans and n_size are defined elsewhere in the script
    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weights[feature], map_location=device))

    for group_id in base.groups:
        for idx in range(base.size[group_id][1]):
            mask_probabilities = predict_bscans(net, np.array(bscans), device, batch_num=n_size)


    # Excerpt from rcia_tools.helpers.for_torch
    import torch
    import torch.nn as nn


    class UNet(nn.Module):
        def __init__(self, n_channels, n_classes, bilinear=True):
            super().__init__()  # must run before any submodule is assigned
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear
            # Input block (attribute name assumed; it must match the keys in the saved weights)
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            factor = 2 if bilinear else 1
            self.down4 = Down(512, 1024 // factor)
            self.up1 = Up(1024, 512, bilinear)
            self.up2 = Up(512, 256, bilinear)
            self.up3 = Up(256, 128, bilinear)
            self.up4 = Up(128, 64 * factor, bilinear)
            self.outc = OutConv(64, n_classes)

        def forward(self, x):
            # Encoder
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            # Decoder: upsample and merge with the matching encoder features
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
            return logits

    class DoubleConv(nn.Module):
        """(convolution => [BN] => ReLU) * 2"""
        ...  # body omitted

    class Down(nn.Module):
        """Downscaling with maxpool then double conv"""
        ...  # body omitted

    class Up(nn.Module):
        """Upscaling then double conv"""
        ...  # body omitted

    class OutConv(nn.Module):
        ...  # body omitted

    def predict_bscans(unet, bscans, device, batch_num=2):
        """Tiles and Segments Bscan"""
        ...  # body omitted

This is essentially all of the PyTorch code in my project. I'm just wondering how I could apply `torch.compile()` to check whether I can benefit from it, or whether any other particular optimisations apply.

Call `model = torch.compile(model)` after loading the state_dict and check whether you see a performance improvement.
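The reverse situation comes up too: if a checkpoint is saved from the compiled wrapper (e.g. `torch.save(net.state_dict(), ...)` after `net = torch.compile(net)`), its keys carry the `_orig_mod.` prefix, and loading it into a plain `UNet` fails with the mirror-image error. Saving `net._orig_mod.state_dict()` instead avoids this. For checkpoints that already have the prefix, a common workaround is to rename the keys before loading. A minimal sketch in plain Python; `strip_compile_prefix` is a hypothetical helper name, not a PyTorch API:

```python
from collections import OrderedDict
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_compile_prefix(state_dict: Mapping[str, V],
                         prefix: str = "_orig_mod.") -> Dict[str, V]:
    """Return a new OrderedDict whose keys have a leading `prefix` removed.

    Keys without the prefix are copied unchanged, so the helper is harmless on
    checkpoints saved from an uncompiled model. The input is not modified.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


if __name__ == "__main__":
    # Plain ints stand in for tensors; a real checkpoint would come from torch.load.
    saved = OrderedDict([("_orig_mod.layer.weight", 0), ("_orig_mod.layer.bias", 1)])
    print(list(strip_compile_prefix(saved)))  # ['layer.weight', 'layer.bias']
```

With PyTorch the call would look like `net.load_state_dict(strip_compile_prefix(torch.load(path, map_location=device)))`, where `path` is a placeholder for the checkpoint file.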


I added

    net = torch.compile(net)

after loading the state_dict, and that worked. However, I can’t really say whether it got faster: I don’t have a proper benchmark, and I’m using a busy platform shared with other colleagues. That said, I did a few runs with and without compile, and it seems to save roughly 10 to 30 s: ~120 s compiled vs. ~150 s not compiled. It may save more in the long run.
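One way to get a fairer comparison on a shared machine is to exclude the first calls, which is when torch.compile actually compiles (it can also recompile when input shapes change, so warm up with representative inputs), and to report the median of several timed runs rather than one total. A minimal sketch, assuming the workload can be wrapped in a zero-argument callable; `benchmark` is a hypothetical helper, not part of PyTorch:

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object],
              warmup: int = 2,
              repeats: int = 5,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time in seconds of fn() over `repeats` runs.

    The first `warmup` calls run untimed, so one-off costs such as the first-call
    compilation are excluded. Pass `sync` (e.g. torch.cuda.synchronize) when fn
    launches asynchronous GPU work, so the clock is read only after it finishes.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to spikes from other jobs.
    return statistics.median(timings)


if __name__ == "__main__":
    # Dummy CPU workload standing in for a model call.
    print(f"median: {benchmark(lambda: sum(i * i for i in range(100_000))):.4f} s")
```

To compare, time the same inputs once through the eager `net` and once through the compiled one, e.g. `benchmark(lambda: predict_bscans(net, batch, device, batch_num=n_size))`, where `batch` is a hypothetical pre-built NumPy array; add `sync=torch.cuda.synchronize` when running on a GPU.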

BTW, should torch.compile() make a difference in CPU-only cases? I tried it there, but saw some unclear messages like:

    No CUDA runtime is found, using CUDA_HOME='/usr'

yet the code worked. I didn’t benchmark it properly, but my first impression is that it didn’t make any difference.

There’s some nuance to benchmarking PT 2.0; you can see what it involves in the msaroufim/mlsys-experiments repository on GitHub (main branch). I’m planning on upstreaming that script to core this week.

CPU should still see speedups, but the type of CPU might have an impact; the most dramatic speedups will be on GPUs with tensor cores.