Hi,
Is there an elegant way to write if-statement
, which determine whether to use the torch amp.
here is my naive implementation.
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
if use_amp:
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()