Model freeze after loading state

Hello, I’m trying to resume training after saving my model. I’ve done this several times and it worked, but now no error is raised and the first batch is never processed.
Here is a quick overview of my code; I’ve omitted all the boilerplate to make it more readable:

def train_model(RUN):
    def train(model, device, dataloader, loss_fn, optimizer, batch_acc):
        # metrics boilerplate
        for x, batch in enumerate(dataloader):
            print("DEBUG") # never prints
            # Train logic
        del data_prot, data_met, weights, target, logits
        return metrics

    ### Testing function

    def test(model, device, dataloader, loss_fn):
        # test boilerplate
        ...

    loss_fn = nn.BCEWithLogitsLoss()  # instantiate the loss module, not just the class
    model = IntegrativeModel(...)  # init boilerplate

    params_to_optimize = [
        {'params': model.parameters()}
    ]
    optim = torch.optim.Adam(params_to_optimize, lr=0.000031, weight_decay=0.0000072)
    path = RUN + "_last.pth"
    best_loss = 5
    if os.path.exists(path):
        torch.set_flush_denormal(True) # Tried this, nothing changed
        checkpoint = torch.load(RUN + "_last.pth")
        start = checkpoint["epoch"] + 1
        # Logging boilerplate
        print(f"Resuming training from epoch {start}") # This prints
    else:
        # Loading pretrained module boilerplate
        model.set_cold() # This function removes the gradient for the pretrained module
        start = 0
        train_log = {}
        with open(RUN + '.json', 'w') as f:
            json.dump(train_log, f)
    model = nn.DataParallel(model)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f'Selected device: {device}') # prints
    # Dataloader creation boilerplate
    num_epochs = 40
    for epoch in range(start, num_epochs):
        train_log[epoch] = {}
        if epoch == 1:
            model.module.set_warm() # I tried calling the function after loading
        elif epoch in [10, 15, 20, 25]:
            batch_acc *= 2
            for par in optim.param_groups:
                par["weight_decay"] = par["weight_decay"]/2
        print("starting epoch...") # prints
        train_loss = train(
            model, device, train_dataloader, loss_fn, optim, batch_acc)
        test_loss = test(model, device, test_dataloader_3, loss_fn)
        print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(
            epoch + 1, num_epochs, train_loss, test_loss))
        # Metrics and logging boilerplate
        with open(RUN+'.json', 'w') as f:
            json.dump(train_log, f)
        if test_loss[0] < best_loss:
            best_loss = test_loss[0]
            # Save best checkpoint (boilerplate)
        torch.save({"epoch": epoch,
                    "optimizer": optim.state_dict(),
                    # ... model state and other entries (boilerplate)
                    }, path)

I don’t know what else to try, and my training is still stuck.
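The bodies of `set_cold`/`set_warm` are not shown in the post. A minimal sketch of what such freeze/unfreeze methods typically do, assuming the pretrained part lives in its own submodule (the class `IntegrativeModelSketch` and the attribute `pretrained` are hypothetical stand-ins), is to toggle `requires_grad` on that submodule's parameters:

```python
import torch
import torch.nn as nn


class IntegrativeModelSketch(nn.Module):
    """Hypothetical stand-in: a 'pretrained' submodule followed by a trainable head."""

    def __init__(self, in_features: int = 8, hidden: int = 4):
        super().__init__()
        self.pretrained = nn.Linear(in_features, hidden)  # placeholder for the pretrained module
        self.head = nn.Linear(hidden, 1)

    def set_cold(self):
        # Freeze the pretrained submodule: autograd computes no gradients for it.
        for p in self.pretrained.parameters():
            p.requires_grad_(False)

    def set_warm(self):
        # Unfreeze the pretrained submodule so it is trained again.
        for p in self.pretrained.parameters():
            p.requires_grad_(True)

    def forward(self, x):
        return self.head(torch.relu(self.pretrained(x)))


model = IntegrativeModelSketch()
model.set_cold()
model(torch.randn(2, 8)).sum().backward()
print(model.pretrained.weight.grad is None)   # frozen weights received no gradient
print(model.head.weight.grad is not None)     # the head was still trained
```

One detail relevant to resuming: `torch.optim.Adam` skips parameters whose `.grad` is `None`, so frozen parameters accumulate no optimizer state until they are unfrozen.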

Based on:

        for x, batch in enumerate(dataloader):
            print("DEBUG") # never prints
            # Train logic

it seems that the DataLoader loop is never entered. Are you seeing these issues only when you are trying to restore the checkpoint?
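One way to narrow this down is to split the first loop iteration into its two steps: creating the iterator and fetching the first batch. A minimal sketch, using a toy `TensorDataset` in place of the real dataset (`probe_first_batch` is a hypothetical helper name):

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def probe_first_batch(dataloader):
    """Time the two steps that run before the loop body executes for the first time."""
    t0 = time.perf_counter()
    it = iter(dataloader)  # with num_workers > 0 this is where worker processes start
    print(f"iterator created after {time.perf_counter() - t0:.2f}s", flush=True)
    batch = next(it)       # blocks until the first batch has been loaded and collated
    print(f"first batch ready after {time.perf_counter() - t0:.2f}s", flush=True)
    return batch


# Toy dataset standing in for the real one.
loader = DataLoader(TensorDataset(torch.arange(10.0)), batch_size=4)
first = probe_first_batch(loader)
```

If the first message never appears, the hang is in starting the iterator; if only the second one is missing, it is in loading or collating the first batch.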

Yes, only after reloading my model. When I don’t reload it, I’ve tried different settings and it trains as expected.

Any ideas on what I could try to debug this? I checked, and it gets stuck right before looping over the DataLoader:

for x, batch in enumerate(dataloader):
    print("DEBUG") # never prints

Both optimizer.zero_grad() and model.train() are executed properly.

I’m seeing the same problem in my code. Crucially, the program gets stuck in a way that doesn’t respond to a KeyboardInterrupt and has to be killed through the OS, so there is no stack trace showing where exactly the problem occurs.

The problem seems to be present in v1.10.1 and v1.11.0 for me. I haven’t tested other versions yet.
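Since Ctrl+C has no effect, the standard-library `faulthandler` module may help: it writes Python stack traces from C-level code (a watchdog thread, or a signal handler), so it can often report where a process is stuck even when `KeyboardInterrupt` is never delivered. A minimal sketch; the 60-second interval is an arbitrary choice and the `time.sleep` call is only a placeholder for the training loop:

```python
import faulthandler
import time

# Dump every thread's Python stack to stderr every 60 s until cancelled.
faulthandler.dump_traceback_later(60, repeat=True)

# On Linux/macOS, also allow an on-demand dump via `kill -USR1 <pid>`.
if hasattr(faulthandler, "register"):
    import signal
    faulthandler.register(signal.SIGUSR1, all_threads=True)

try:
    time.sleep(0.1)  # placeholder for the training loop
finally:
    faulthandler.cancel_dump_traceback_later()
```

This only covers the main process, not DataLoader worker processes; an external sampler such as `py-spy dump --pid <pid>` can be pointed at those as well.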

In my view, this issue depends only on the DataLoader; maybe the model saving and reloading is a coincidence? I’m not sure. Can you add some print statements inside your Dataset class to find out where the process gets stuck?
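A minimal sketch of that suggestion: a hypothetical wrapper `DebugDataset` that logs each `__getitem__` call together with the process id, so the output shows which item (and which worker, if any) stopped making progress. The toy tensor dataset is a placeholder, and `num_workers=0` is used here to take worker processes out of the picture:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class DebugDataset(Dataset):
    """Hypothetical wrapper that logs the start and end of every item load."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        print(f"[pid {os.getpid()}] start item {idx}", flush=True)
        item = self.base[idx]
        print(f"[pid {os.getpid()}] done item {idx}", flush=True)
        return item


# num_workers=0 loads everything in the main process, which rules out
# worker start-up and inter-process communication as the cause of a hang.
loader = DataLoader(DebugDataset(torch.arange(6.0)), batch_size=3, num_workers=0)
batches = [b for b in loader]
```

If the real dataset works with `num_workers=0` but hangs with workers enabled, that points at worker start-up rather than at the data itself.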

I have tracked down the error in my code a bit further, and the DataLoader seems to run fine on my side. The network inference, however, seems to get stuck at one point in the network.

From the looks of it, my problem is different from OP’s, so I will open a separate topic as soon as I have enough information.
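To find the module where a forward pass stops, one option is to attach a forward pre-hook and a forward hook to every submodule and print its name: the last "enter" line without a matching "exit" marks where execution stalled. A sketch on a toy `nn.Sequential` (`trace_forward` is a hypothetical helper name):

```python
import torch
import torch.nn as nn


def trace_forward(model: nn.Module):
    """Print each submodule's name just before and just after its forward runs."""
    handles = []
    for name, module in model.named_modules():
        label = name or model.__class__.__name__  # the root module has the empty name ""
        handles.append(module.register_forward_pre_hook(
            lambda mod, inp, label=label: print(f"enter {label}", flush=True)))
        handles.append(module.register_forward_hook(
            lambda mod, inp, out, label=label: print(f"exit  {label}", flush=True)))
    return handles  # call .remove() on each handle to detach the hooks


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
hooks = trace_forward(net)
out = net(torch.randn(2, 4))
for h in hooks:
    h.remove()
```

On CUDA, kernels run asynchronously, so a stall may only show up at a later synchronizing call; setting the environment variable `CUDA_LAUNCH_BLOCKING=1` makes the printed position more trustworthy, at the cost of speed.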

I’ve managed to solve my issue. It seems Adam holds some device-dependent tensors (correct me if I’m wrong here), and they behave oddly during loading. I passed a map_location to torch.load and made sure the optimizer referenced the model parameters on the right device, and that solved the problem. I have no idea why this caused the DataLoader to freeze.
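A minimal sketch of that fix, assuming a checkpoint dict with `"epoch"`, `"model"` and `"optimizer"` keys (the `"model"` key, the temporary file and the toy `nn.Linear` are placeholders, not the poster's code): move the model to its device before building the optimizer, load the checkpoint with `map_location`, then restore both state dicts.

```python
import os
import tempfile

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_model_and_optimizer():
    # Move the model first, so the optimizer references parameters that
    # already live on the device they will be trained on.
    model = nn.Linear(4, 1).to(device)  # stand-in for the real model
    optim = torch.optim.Adam(model.parameters(), lr=0.000031, weight_decay=0.0000072)
    return model, optim


# One training step, then save a checkpoint.
model, optim = make_model_and_optimizer()
model(torch.randn(8, 4, device=device)).sum().backward()
optim.step()
path = os.path.join(tempfile.mkdtemp(), "run_last.pth")
torch.save({"epoch": 0,
            "model": model.state_dict(),
            "optimizer": optim.state_dict()}, path)

# Resume: rebuild, then load every saved tensor directly onto `device`.
model, optim = make_model_and_optimizer()
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint["model"])
optim.load_state_dict(checkpoint["optimizer"])
start = checkpoint["epoch"] + 1
```

The order matters because `Optimizer.load_state_dict` casts the loaded per-parameter state (for Adam, `exp_avg` and `exp_avg_sq`) to the device of the parameter it belongs to, so the parameters should already be on the target device when it is called.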

It would probably be a good idea to add a to(device) method to the optimizer to allow explicitly moving those tensors to the right device, since this is not the first time I’ve had trouble with these Adam tensors. I think they also tend to leave dangling pointers to GPU memory, causing memory leaks.
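PyTorch optimizers do not have a `.to()` method. A common workaround is a small helper (the name `optimizer_to` is hypothetical) that walks `optim.state` and moves every tensor it finds. A sketch, assuming the model's parameters are moved with `model.to(device)` first:

```python
import torch
import torch.nn as nn


def optimizer_to(optim: torch.optim.Optimizer, device) -> None:
    """Move every tensor stored in the optimizer state to `device`.

    optim.state maps each parameter to a dict of per-parameter state;
    for Adam this includes the running averages exp_avg and exp_avg_sq.
    """
    for state in optim.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters())
model(torch.randn(8, 4)).sum().backward()
optim.step()                 # Adam state is created on the CPU here
model.to(device)             # parameters are moved in place; the optimizer keeps its references
optimizer_to(optim, device)  # the state tensors have to be moved separately
```

`model.to(device)` keeps the same `Parameter` objects, which is why the optimizer's parameter references stay valid while its state does not follow along.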