Hi, I’m using AMP in my training process and validating the model with torch.no_grad periodically, say every 1000 updates. For the first 1000 updates, the program works correctly, showing lower GPU memory consumption and faster speed compared to regular FP32 training. However, as soon as validation starts, memory quickly blows up and a CUDA OOM error is thrown. On the other hand, FP32 training with exactly the same hyperparameters and training data does not OOM throughout the whole training process.
```python
@torch.no_grad()
def validate(args, model, dev_itr):
    ...

def train():
    for i, batch in enumerate(train_itr):
        with torch.cuda.amp.autocast(True):
            loss = model(batch)  # batch and model both fp32
        loss.backward()
        optimizer.step()
        if i % val_every == 0:
            with torch.cuda.amp.autocast(True):
                validate(args, model, dev_itr)  # OOM
```
If I disable the autocast right outside validate(), the training proceeds successfully.
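For reference, here is a minimal, CPU-runnable sketch of the working variant: validate() is called without any surrounding autocast context and relies only on @torch.no_grad(). The model and dev_itr below are toy stand-ins (a tiny Linear layer and a list of random batches), not my actual setup.

```python
import torch

@torch.no_grad()
def validate(model, dev_itr):
    # No autocast wrapper here; no computation graph is built either.
    total = 0.0
    for batch in dev_itr:
        out = model(batch)
        total += out.sum().item()
    return total

# Toy stand-ins for the real model and validation iterator.
model = torch.nn.Linear(4, 2)
dev_itr = [torch.randn(8, 4) for _ in range(3)]

val = validate(model, dev_itr)

# Gradient tracking resumes normally once validate() returns.
out = model(dev_itr[0])
assert out.requires_grad
```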
Any clue on the source of the problem would be helpful.