torch.cuda.amp equivalent of apex.amp.initialize?

I’m working on a project forked from a HuggingFace code base that used NVIDIA’s apex.amp.initialize.

            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.run_config['devices']['fp16_opt_level']
            )

What’s the equivalent function call using torch.cuda.amp?

@mcarilli, can I ask you for your insight?

Also, a follow-up question: apex had optimization levels, e.g. ‘fp16_opt_level’: ‘O1’, whose config comment reads “For fp16: Apex AMP optimization level selected in [‘O0’, ‘O1’, ‘O2’, and ‘O3’]”. Does torch.cuda.amp have anything similar I need to keep in mind?

You don’t need to initialize the model and optimizer. Instead, use torch.cuda.amp.autocast together with torch.cuda.amp.GradScaler, as described here.
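A minimal sketch of that pattern, assuming a toy nn.Linear model, SGD, and random data as hypothetical stand-ins for the real student model and dataset. The model and optimizer are built normally and never passed through an initialize-style call. To keep the sketch runnable without a GPU, amp is switched off via enabled= when CUDA is unavailable. Recent PyTorch releases deprecate the torch.cuda.amp spellings in favor of torch.amp.autocast("cuda", ...) and torch.amp.GradScaler("cuda", ...), but the older names still work (with a warning).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real model, optimizer, and data.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp only has an effect on a GPU

model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# No amp.initialize(model, optimizer, ...): both objects are used as-is.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

inputs = torch.randn(32, 16, device=device)
targets = torch.randn(32, 4, device=device)

for step in range(5):
    optimizer.zero_grad()
    # Run the forward pass and loss computation under autocast.
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # Backward on the scaled loss; step() unscales the grads and skips the
    # update if they contain infs/NaNs; update() adjusts the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

final_loss = loss.item()
```

Note that the backward pass, optimizer step, and scaler calls sit outside the autocast block; only the forward pass and loss computation are autocast.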

No. Right now native amp is similar to apex.amp O1. We are experimenting with an O2-style mode, but it is still a work in progress.


The two ingredients of native amp (torch.cuda.amp.autocast and torch.cuda.amp.GradScaler) do not affect the model or optimizer in a stateful way.

autocast casts the inputs of the ops on its lists on the fly, as each op is called. The cast copies are temporaries. They may end up stashed for backward, but that’s the longest they last. They never overwrite any model attributes.
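A small sketch showing that the casts are temporary: the output of a linear layer under autocast comes back in reduced precision, while the layer's weight and the input tensor remain float32. As an assumption to keep it runnable without a GPU, it uses the device-generic torch.autocast (torch.cuda.amp.autocast() is equivalent to torch.autocast("cuda")) and falls back to CPU autocast with bfloat16 when CUDA is unavailable.

```python
import torch
import torch.nn as nn

# Assumption: float16 on CUDA if available, otherwise CPU autocast with bfloat16.
if torch.cuda.is_available():
    device, low_dtype = "cuda", torch.float16
else:
    device, low_dtype = "cpu", torch.bfloat16

model = nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device)

with torch.autocast(device_type=device, dtype=low_dtype):
    # linear is on autocast's lower-precision list, so the float32 input
    # and weight are cast to low_dtype for this call only.
    out = model(x)

print(out.dtype)           # low_dtype: produced from the temporary casts
print(model.weight.dtype)  # torch.float32: the parameter was never replaced
print(x.dtype)             # torch.float32: the input tensor is untouched
```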

GradScaler scales the loss and unscales the .grad attributes of the optimizer’s params in place (either in an explicit call to scaler.unscale_(optimizer) or implicitly inside scaler.step(optimizer)), but it doesn’t overwrite or replace any optimizer attributes.
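A sketch of the explicit scaler.unscale_(optimizer) path, which is the usual choice when gradients need to be inspected or clipped at their true magnitude before stepping. It also snapshots the optimizer's hyperparameters before and after, to show the scaler leaves them alone. The toy model, data, and max_norm=1.0 are hypothetical, and amp is disabled via enabled= when no GPU is present so the script still runs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def hyperparams(opt):
    # Every param-group setting except the parameter list itself.
    return [{k: v for k, v in g.items() if k != "params"} for g in opt.param_groups]

groups_before = hyperparams(optimizer)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # .grad now holds scaled gradients
scaler.unscale_(optimizer)     # unscale each param's .grad in place
# Gradients are back at their true magnitude, so clipping sees real norms.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
clipped_norm = torch.norm(
    torch.stack([p.grad.detach().norm() for p in model.parameters()])
).item()
scaler.step(optimizer)  # unscale_ already ran, so step() won't unscale again
scaler.update()

groups_after = hyperparams(optimizer)
```

The scaler's own bookkeeping (scale factor, per-optimizer unscale state) lives inside the GradScaler object, not on the optimizer.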

This is by design, to avoid the black-box monkey-patching spaghetti that can take place in apex.amp.initialize. torch.cuda.amp has no equivalent to initialize, so just delete the amp.initialize call (the apex import above it is then no longer needed either).

To toggle amp on or off based on run_config without writing divergent code paths, use the enabled= argument of autocast and GradScaler, as shown here. There’s no equivalent of opt_levels in native amp: it’s either on or off, so you’ll have to decide how to interpret (or change) fp16_opt_level.
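One possible interpretation, sketched below: treat ‘O0’ (pure fp32 in apex) as “amp off” and every other apex level as “amp on”. The helper names amp_enabled and train_step and the exact run_config layout are hypothetical, modeled on the run_config['devices']['fp16_opt_level'] lookup in the original snippet. The same train_step code runs whether amp is on or off, because enabled= does the toggling.

```python
import torch
import torch.nn as nn

def amp_enabled(run_config):
    """Hypothetical mapping: 'O0' means fp32 only; any other apex opt level
    just means 'use native amp' (native amp itself has no levels)."""
    level = run_config["devices"].get("fp16_opt_level", "O0")
    return level != "O0" and torch.cuda.is_available()

def train_step(model, optimizer, scaler, use_amp, inputs, targets):
    optimizer.zero_grad()
    # Identical code path with amp on or off; enabled= does the switching.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

# Hypothetical config shaped like the one read in the original snippet.
run_config = {"devices": {"fp16_opt_level": "O1"}}
use_amp = amp_enabled(run_config)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(32, 16, device=device)
y = torch.randn(32, 4, device=device)
loss_value = train_step(model, optimizer, scaler, use_amp, x, y)
```

Under this mapping, ‘O1’, ‘O2’, and ‘O3’ all collapse to the same behavior. The weight-casting that apex’s O2/O3 performed has no counterpart here, so the parameters stay float32.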
