Can the autocast context manager be used around the whole training loop?

Often one has a method like train_one_epoch. Can I just use autocast around the call to this method? Or must we wrap only the call to the model, after data iteration and the move to the device are done?

I would stick to the recommended usage of moving the model’s forward pass and loss calculation into the autocast context. These pieces from the docs might be worth checking against your use case:

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

Floating-point Tensors produced in an autocast-enabled region may be float16. After returning to an autocast-disabled region, using them with floating-point Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) produced in the autocast region back to float32 (or other dtype if desired). If a Tensor from the autocast region is already float32, the cast is a no-op, and incurs no additional overhead.
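
As a minimal sketch of that recommended pattern (assuming the usual model, criterion, optimizer, and loader are already set up, and using a GradScaler as in the AMP examples): the forward pass and the loss computation stay inside autocast, while backward() and the optimizer step run outside of it:

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    # backward and optimizer step run outside of the autocast region
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()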

Besides that, autocast uses an internal cache to avoid re-casting constant tensors. This cache would never be freed if you wrapped the entire code in a single autocast region, so you might have to either disable it or risk increasing memory usage.
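
If you do end up with a long-lived autocast region, here is a sketch of the two options (reusing the names from the sketch above, and assuming a PyTorch version where autocast exposes the cache_enabled argument):

# Option 1: disable the cast cache for the long-lived region
with torch.cuda.amp.autocast(cache_enabled=False):
    for data, target in loader:
        output = model(data)

# Option 2: keep the cache, but clear it manually between iterations
with torch.cuda.amp.autocast():
    for data, target in loader:
        output = model(data)
        torch.clear_autocast_cache()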

Does it cache both model parameters and model inputs? E.g. if only model parameters were cached, one could wrap the whole inference loop in autocast.

What I also don’t understand is at what moment tensors are cast to fp16. This matters, for example, when implementing forced-fp32 module wrappers: there is an existing external module class that can’t be modified, and we’d like to wrap it / derive a new module class that always runs in fp32, so that its inputs are not cast to fp16 if they are not coming from prior fp16 processing. Concretely, there is a ROIPool module in wetectron/roi_pool.py at master · NVlabs/wetectron · GitHub, marked with apex.amp.float_function. I’d like to replace it with torchvision.ops.RoIPool, but how do I wrap the RoIPool module class such that the rois argument never gets cast to fp16?

In the Autocast and Custom Autograd Functions section of the docs, use cases for custom methods are described. In your use case, since you don’t want the inputs to be cast to float16, you could disable autocast via a nested context manager:

In all cases, if you’re importing the function and can’t alter its definition, a safe fallback is to disable autocast and force execution in float32 (or dtype) at any points of use where errors occur:

with autocast():
    ...
    with autocast(enabled=False):
        output = imported_function(input1.float(), input2.float())
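
For completeness, that section also covers ops wrapped in a custom torch.autograd.Function: custom_fwd(cast_inputs=torch.float32) casts incoming float16 CUDA tensors back to float32 and runs the forward with autocast disabled. A rough sketch (the exp here is just a placeholder for a float32-only op):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyFloat32Func(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * input.exp()

with torch.cuda.amp.autocast():
    x = torch.randn(4, 4, device='cuda', requires_grad=True)
    out = MyFloat32Func.apply(x)  # runs in float32 even inside autocast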

I’m importing a module class, so I’d like to make a wrapping module class such that autocast doesn’t cast the unprocessed inputs to this module to fp16. Can I achieve that?

Pseudocode:

from torchvision.ops import RoIPool as _RoIPool

class RoIPool(_RoIPool):
    def forward(self, x, rois):
        # how to decorate this function so that rois is never converted to fp16?
        return super().forward(x.float(), rois).type_as(x)

# MyModel processes x and does not process rois (it pipes them directly to the module)
model = MyModel(RoIPool(...))  # RoIPool takes output_size and spatial_scale

for x, rois in data_loader:
    with autocast(enabled=True):
        model(x, rois).backward()
    ...

Part of my problem is that I don’t understand how autocast works: at what point / how are casts done?

Disable autocast in the forward as given in the previous example:

class RoIPool(_RoIPool):
    def forward(self, x, rois):
        with autocast(enabled=False):
            out = super().forward(x.float(), rois)
        return out
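
As a quick sanity check of that wrapper (a sketch, assuming the import of torchvision.ops.RoIPool as _RoIPool from the pseudocode above; its constructor takes output_size and spatial_scale):

pool = RoIPool(output_size=(7, 7), spatial_scale=1.0).cuda()

x = torch.randn(2, 3, 32, 32, device='cuda').half()           # pretend x came from earlier fp16 layers
rois = torch.tensor([[0., 1., 1., 20., 20.]], device='cuda')  # float32 and never touched by autocast

with autocast():
    out = pool(x, rois)

print(out.dtype)   # torch.float32, since the pooling ran on x.float() with autocast disabled
print(rois.dtype)  # torch.float32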

Casts are done internally during the dispatching, but I would recommend not relying on the internal implementation, as it might also depend on the backend (e.g. I don’t know how exactly autocast works on the CPU).
Here is a code snippet showing the input and output dtypes for layers that can cast to float16 and for others that keep float32 due to the needed numerical precision:

import torch
import torch.nn as nn

def hook(module, input, output):
    # input is a tuple of tensors, output is a single tensor for these modules
    print(module)
    print([i.dtype for i in input])
    print(output.dtype)

lin1 = nn.Linear(10, 10).cuda()
lin1.register_forward_hook(hook)

lin2 = nn.Linear(10, 10).cuda()
lin2.register_forward_hook(hook)

sm = nn.Softmax(dim=1).cuda()
sm.register_forward_hook(hook)

x = torch.randn(1, 10).cuda()

with torch.cuda.amp.autocast():
    out = lin1(x)    # linear layers autocast their inputs to float16
    out = lin2(out)
    out = sm(out)    # softmax is on the float32 list, so its output is float32

Thank you! At least having a way to debug/trace these casts is awesome! Maybe worth adding this bit to documentation!

About module-class autocast decorators, module-object decorators, and function decorators: should we have something in the standard library supporting this? Or is there already something? This way, modularity/nesting could move to the more familiar class level / object level / function level, away from context-manager scopes (and provide more information when inspecting the model in a debugger) when that’s more convenient. This can be useful sometimes and supports a more functional code style.
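
For what it’s worth, autocast itself can already be used as a decorator, e.g. on a model’s forward method, which covers at least the function-level case; a minimal sketch:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

class AutocastModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10)

    @autocast()  # the forward runs in an autocast region
    def forward(self, input):
        return self.lin(input)

# analogously, @autocast(enabled=False) forces a forward to run without autocasting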

Is there a way to trace all amp casts (of inputs, weights, activations)?

If you are looking into scripting a model with amp, then you could try to install the latest nightly, use the nvfuser backend and enable autocasting via torch._C._jit_set_autocast_mode(True) (which is a beta feature).
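
A rough sketch of what that could look like (heavily hedged: this is a beta feature, the exact flags and the set of supported patterns may change between nightlies, and the nvfuser setup is omitted here):

import torch
from torch.cuda.amp import autocast

torch._C._jit_set_autocast_mode(True)  # beta flag mentioned above

@torch.jit.script
def scripted_mm(a, b):
    with autocast():
        return torch.mm(a, b)

a = torch.rand(8, 8, device='cuda')
b = torch.rand(8, 8, device='cuda')
print(scripted_mm(a, b).dtype)  # torch.float16 if the JIT autocast pass kicked in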

Oops, I meant trace in the sense of “log all amp casts” in order to understand at what points exactly the casts are happening

Should I be able to trace the model under autocast and then inspect script_module.code to view all the casts? Is the traced model + autocast behavior analogous to eager model + autocast during training mode?

Basically, I don’t understand at what point the casts are done (what does “during the dispatching” mean?), in what precision activations are preserved, how no-cast-needed exceptions are handled, when model inputs are cast, whether model inputs will still be downcast if the first op in the model doesn’t support lower precision, etc.

There isn’t a “tracing debug mode” but the visualization of the computation graph should show the casts (you wouldn’t need to script the model to create it).
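
One way to create such a visualization (an assumption on my side: using the third-party torchviz package; the cast nodes inserted by autocast should appear in the autograd graph in front of the matmul):

import torch
import torch.nn as nn
from torchviz import make_dot  # third-party: pip install torchviz

lin = nn.Linear(10, 10).cuda()
x = torch.randn(1, 10, device='cuda')

with torch.cuda.amp.autocast():
    out = lin(x)

make_dot(out, params=dict(lin.named_parameters())).render('autocast_graph', format='png')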

So instead of scripting, do tracing? Can one run tracing inside of the autocast block? Or should one put the autocast block inside the model to be traced? An example of this would probably be useful for the official docs/tutorial.

A recent issue on torch.jit.trace + autocast: torch.jit.trace doesn't work with autocast on Conv node. · Issue #84092 · pytorch/pytorch · GitHub