Elegant way to determine whether to use torch.amp

Hi,

Is there an elegant way to write the if-statement that determines whether to use torch.amp?

Here is my naive implementation:

from torch import optim
from torch.cuda.amp import autocast, GradScaler

model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        if use_amp:
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(input)
            loss = loss_fn(output, target)
            
            loss.backward()
            optimizer.step()
             

The code looks correct and I’m not sure if there is a more elegant way.

Wouldn’t this work?

model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler(enabled=use_amp)

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        
        with autocast(enabled=use_amp):
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


Yes, your approach is the recommended way, so thanks for following up on this (I missed adding this approach to this thread)!
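
For context: this works because both autocast and GradScaler become no-ops when constructed with enabled=False, so the plain FP32 path is recovered automatically. Below is a minimal sketch of the same pattern written against the unified torch.amp namespace, assuming a recent PyTorch release where torch.amp.GradScaler and torch.amp.autocast accept a device type as the first argument; Net, loss_fn, epochs, and data are placeholders as above, and lr=0.1 is just an example value.

from torch import optim
from torch.amp import autocast, GradScaler

use_amp = True  # single flag toggles mixed-precision training

model = Net().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# With enabled=False, scale/step/update become no-ops (step just calls optimizer.step()).
scaler = GradScaler("cuda", enabled=use_amp)

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # With enabled=False, autocast leaves the forward pass in full precision.
        with autocast("cuda", enabled=use_amp):
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()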