Save and load a mixed precision model

I am training a DistilBERT model for text classification and using the apex amp package to train a mixed precision version. Here is the code used to train the model:

    # define the model
    model = BertMulticlassifier(....)

    # define optimizer and scheduler
    optimizer = Config.OPTIMIZER(model.parameters(), lr=Config.LR, bias_correction=False)
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
    scheduler = Config.SCHEDULER(optimizer=optimizer,
                                 num_warmup_steps=400,
                                 num_training_steps=len(train_data_loader) * Config.EPOCHS)

    model.compile(optimizer=optimizer, loss=loss_func, metrics=metric, scheduler=scheduler)
    model.fit(train_loader=train_data_loader, val_loader=val_data_loader, epochs=Config.EPOCHS)

    # save the trained model
    torch.save(model.state_dict(), model_name)
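For reference, the apex checkpointing docs also suggest saving the amp state alongside the model and optimizer if you want to resume mixed precision training later; a rough sketch, assuming amp.initialize was actually called ('amp_checkpoint.pt' is a placeholder name):

    # sketch following the apex checkpointing docs
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict()   # stores the loss scaler state
    }
    torch.save(checkpoint, 'amp_checkpoint.pt')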

Load the model and run it on CPU:

model = BertMulticlassifier(....)
model.load_state_dict(torch.load('model.bin', map_location='cpu'))
model.eval()
model.predict()

Comparing the inference time of the plain DistilBERT model with the mixed precision DistilBERT version, they are more or less the same, which does not make sense to me. I expected the mixed precision version to take less time.

So I am wondering whether amp only used mixed precision during training and then saved the model in float32, or whether the problem comes from the load function? What is the right way to load the model as a mixed precision version? Thanks for taking a look @ptrblck

It seems you are using a higher-level API, which provides a fit and predict function, and which might use automatic mixed precision internally.
Which library are you using and how did you specify to use amp?

Thanks @ptrblck. Basically, I implemented my own class following the Keras schema, using the Hugging Face transformers package.
Here is the internal part of the fit() function for one epoch:

    losses = []
    predictions = []
    for step, batch in enumerate(data_loader):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        targets = batch["targets"].to(self.device)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        loss = self.loss_func(outputs, targets)
        ......
        # scale the loss with apex amp and backpropagate
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
    return np.mean(predictions), np.mean(losses)

Also, PyTorch provides another way to quantize the model: dynamic quantization. I have used it on the DistilBERT model trained with mixed precision, and the quantized model performs as well as the float32 version!

So my thought is that mixed precision is only used for the training phase and the model is saved entirely as float32??

model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
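To check what is actually stored in the checkpoint, one can print the dtypes of the saved tensors; a quick sketch, assuming the file name 'model.bin':

    state_dict = torch.load('model.bin', map_location='cpu')
    for name, tensor in state_dict.items():
        print(name, tensor.dtype)   # e.g. torch.float32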

Yes, the parameters and buffers would be stored in float32. If you don’t use autocast during inference, then you would execute the standard FP32 model.
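A minimal sketch of running the forward pass under autocast (this assumes the model and inputs are on the GPU; input_ids and attention_mask stand in for one prepared batch):

    model.cuda().eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            outputs = model(input_ids=input_ids.cuda(),
                            attention_mask=attention_mask.cuda())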

Note that mixed precision utilities are now usable in PyTorch directly without apex, so I would recommend updating the code to torch.cuda.amp.
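A minimal sketch of the native recipe applied to your training loop (autocast for the forward pass, GradScaler for the backward pass; the names self, data_loader, etc. are taken from your snippet):

    # create the scaler once, e.g. in __init__
    scaler = torch.cuda.amp.GradScaler()

    for step, batch in enumerate(data_loader):
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        targets = batch["targets"].to(self.device)

        self.optimizer.zero_grad()
        # run the forward pass and loss computation in mixed precision
        with torch.cuda.amp.autocast():
            outputs = self(input_ids=input_ids, attention_mask=attention_mask)
            loss = self.loss_func(outputs, targets)

        # scale the loss, backprop, then unscale before clipping the gradients
        scaler.scale(loss).backward()
        scaler.unscale_(self.optimizer)
        nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        scaler.step(self.optimizer)
        scaler.update()
        self.scheduler.step()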

Thanks so much @ptrblck. I have changed my code to use autocast from torch.cuda.amp. However, I noticed that before, with

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

I could use a batch_size of 16, whereas with autocast I cannot use more than 8! It seems some operations are not cast to float16 as they are in apex. I think autocast implements the O1 level of apex.
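One way to see which operations autocast runs in float16 is to check the output dtypes inside the autocast region; a quick sketch:

    with torch.cuda.amp.autocast():
        a = torch.randn(8, 8, device='cuda')
        b = torch.randn(8, 8, device='cuda')
        print((a @ b).dtype)             # matmul is autocast to float16
        print(torch.log(a.abs()).dtype)  # log runs in float32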

The native amp implementation corresponds to opt_level="O1" in apex.amp, as its usability is the most user-friendly and it should support all use cases. We are experimenting with other utilities to lower the memory usage, but for now we strongly recommend sticking to the native implementation for the best support.
