How to alternate inference with modified model and training?

I’d like to alternate inference and training in a ResNet-type model, but I need to reconfigure the model during inference, and I have some questions about that.

  1. Can I unload the model from the GPU by calling model.to('cpu'), make a modified copy (and run it on the GPU), and then move the original back to the GPU by calling model.to('cuda')? In other words, is moving the model GPU->CPU->GPU a lossless operation? (And what happens to the parameters that were passed to the optimizer? I don’t want to lose the optimizer state.)

  2. What is the best way to make a copy of an in-memory model? I could save it and then reload a copy, but I’m not sure whether that is necessary just to make a copy.

  3. If I want to run inference in half precision (more than 2x faster in this case), can I convert the model to half and then convert it back? Is that lossless? (i.e., does the model keep a copy of everything, or does it replace the weights with half-precision versions in place?)

  1. I’m not sure what “lossless” means in this context. The optimizer stores references to the model parameters, so depending on what your modification does, you might delete these references. Generally, you can move models between different devices (and this operation is also differentiable).

  2. I would recommend creating a deepcopy of the state_dict.

  3. Calling half() on the model might work, but it could also create NaN/Inf outputs, since FP16 can easily overflow. Especially during training we therefore recommend using automatic mixed precision, but your inference use case might also work.
    All parameters will be converted to HalfTensors in place (no FP32 copy is kept) and back to FloatTensors, and will thus lose precision, as seen in this example:

```python
import torch

x = torch.randn(100)
y = x.half().float()  # FP32 -> FP16 -> FP32 round trip
print((x - y).abs().max())  # non-zero: precision was lost
```
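A small sketch that checks the three points above, assuming a recent PyTorch version. The tiny nn.Linear model and all variable names are illustrative stand-ins, not taken from the thread. It shows that a device round trip keeps the same Parameter objects (so the optimizer's references stay valid) with bit-identical values, that state_dict() tensors share storage with the model while a deepcopy does not, and that half() converts in place, so converting back loses precision.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
# Use the GPU if one is available; the checks also run on a CPU-only machine.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 1. Device round trip: module.to() swaps each parameter's data in place,
#    so the Parameter objects held by the optimizer remain the model's parameters.
params_before = list(model.parameters())
weight_before = model.weight.detach().clone()
model.to("cpu")
model.to(device)
same_objects = all(a is b for a, b in zip(params_before, model.parameters()))
optimizer_refs_valid = all(
    p is q for p, q in zip(optimizer.param_groups[0]["params"], model.parameters())
)
bit_exact = torch.equal(weight_before, model.weight.detach())

# 2. state_dict() returns tensors that share storage with the parameters;
#    a deepcopy gives independent tensors.
sd = model.state_dict()
sd_copy = copy.deepcopy(sd)
shares_storage = sd["weight"].data_ptr() == model.weight.data_ptr()
copy_is_independent = sd_copy["weight"].data_ptr() != model.weight.data_ptr()

# 3. half() converts the parameters in place; no FP32 copy is kept,
#    so converting back to float32 does not restore the original values.
model.half()
is_half = model.weight.dtype == torch.float16
model.float()
half_round_trip_exact = torch.equal(weight_before, model.weight.detach())

print(same_objects, optimizer_refs_valid, bit_exact)
print(shares_storage, copy_is_independent)
print(is_half, half_round_trip_exact)
```

Note that only the parameters are covered here; the optimizer's own state buffers are not moved by model.to().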

@ptrblck - Thanks! By lossless I mean that a model can be round-tripped GPU-CPU-GPU and resume training as though the round trip had not taken place: the weights are exactly the same, the optimizer state is the same, and nothing needs to be set up again. I’m not sure what you mean by moving between devices being differentiable, but it sounds like the optimizer should continue working after a round trip.

Putting all of your suggestions together, the train/inference/train loop could look something like this in simplified code:

```python
import copy

import torch
from torch import optim

# device, loader and loss_func are defined elsewhere
# different models, but with compatible weights
model = ...
inference_model = ...
optimizer = optim.AdamW(model.parameters(), ...)

while True:

    model.to(device)
    model.train()

    for (data, target) in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        logits = model(data)
        loss = loss_func(logits, target)
        loss.backward()
        optimizer.step()

    model.to('cpu')

    inference_model.load_state_dict(copy.deepcopy(model.state_dict())) # is deepcopy needed here?
    inference_model.half() # doesn't affect original model's weights
    inference_model.to(device)
    inference_model.eval()

    with torch.no_grad():
        for (data, target) in loader:
            # a half-precision model expects half-precision inputs
            data, target = data.to(device).half(), target.to(device)
            logits = inference_model(data)

    inference_model.to('cpu') # maybe not necessary, just need to unload it, don't need a copy of the weights
```

Does that look about right?

Main things I’m not sure about:
(a) Do I really not need to do anything with the optimizer; would it “just work”? I now realize that in typical code the optimizer is created with model.parameters() before model.to(device), so it must work across moves :slight_smile: Are there any implications of doing many moves versus just one?
(b) Is copying the state dict from one model to another as shown correct?

Thanks for the clarification.

(a) A stateless optimizer might work, but if it uses internal buffers, as e.g. Adam does, stepping after the parameters were moved to another device would raise a device mismatch error, since model.to() does not move the optimizer state. I would generally recommend recreating the optimizer and loading its state_dict. An annoyance is that you would need to push the internal states to the new device. Here is a code snippet adapted from this issue (a to() method for optimizers might be useful here :confused:):
```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Single update
out = model(torch.randn(1, 1).cuda())
out.mean().backward()
w0 = model.weight.clone()
optimizer.step()
w1 = model.weight.clone()
print((w1 - w0).abs().max())
optimizer.zero_grad()

# Move to CPU and verify update
sd = optimizer.state_dict()
model.cpu()
optimizer_cpu = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_cpu.load_state_dict(sd)
# Make sure all internal state tensors (e.g. exp_avg, exp_avg_sq) are on the CPU
for state in optimizer_cpu.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cpu()

out = model(torch.randn(1, 1))
out.mean().backward()
w0 = model.weight.clone()
optimizer_cpu.step()
w1 = model.weight.clone()
print((w1 - w0).abs().max())
```
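PyTorch optimizers have no to() method, so here is a minimal sketch of one as a plain function. The name optimizer_to is hypothetical, not a PyTorch API; it applies the same loop as the snippet above to an existing optimizer instead of a recreated one. The sketch also shows why the loop from the question keeps working: model.to() moves the parameters but leaves Adam's exp_avg/exp_avg_sq buffers where they were created, so in that loop they stay on the GPU (and keep using GPU memory) while the model is unloaded, and steps only happen once the model is back on the GPU.

```python
import torch
import torch.nn as nn


def optimizer_to(optimizer: torch.optim.Optimizer, device) -> None:
    """Hypothetical helper (not part of PyTorch): move all optimizer state tensors to `device`."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step on `device` creates Adam's exp_avg/exp_avg_sq buffers there.
model(torch.randn(1, 1, device=device)).mean().backward()
optimizer.step()
optimizer.zero_grad()

# Moving the model does not touch the optimizer state.
model.cpu()
buffer_device_before = optimizer.state[model.weight]["exp_avg"].device.type

# Move the state explicitly so that a CPU step sees matching devices.
optimizer_to(optimizer, "cpu")
buffer_device_after = optimizer.state[model.weight]["exp_avg"].device.type

model(torch.randn(1, 1)).mean().backward()
optimizer.step()
print(buffer_device_before, buffer_device_after)
```

Calling optimizer_to(optimizer, device) right after model.to('cpu') in the training loop, and again after moving the model back, would be one way to also free the optimizer's GPU memory during inference, at the cost of extra transfers per round trip.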
(b) Yes, the deepcopy seems to be fine here. The load_state_dict method calls param.copy_ as seen here, so creating a deepcopy of the state_dict shouldn’t be necessary, but it might not hurt. :wink:
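A quick sketch of why the deepcopy is optional, with two nn.Linear layers standing in for the training and inference models (an assumption; any pair with matching state_dict keys and shapes would do). Since load_state_dict copies the values into the target model’s own tensors, the copy is independent of the source without a deepcopy: converting it to half precision leaves the original in FP32, and later training updates do not leak into it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)            # stands in for the training model
inference_model = nn.Linear(4, 2)  # same architecture, so the keys and shapes match

# No deepcopy: load_state_dict copies the values into inference_model's own tensors.
inference_model.load_state_dict(model.state_dict())
values_match = torch.equal(model.weight, inference_model.weight)
separate_storage = model.weight.data_ptr() != inference_model.weight.data_ptr()

# Converting the copy to half precision leaves the original in float32.
inference_model.half()
original_still_fp32 = model.weight.dtype == torch.float32

# A later in-place update of the original does not change the copy.
loaded = inference_model.weight.detach().clone()
with torch.no_grad():
    model.weight.add_(1.0)
copy_unchanged = torch.equal(inference_model.weight.detach(), loaded)

print(values_match, separate_storage, original_still_fp32, copy_unchanged)
```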