AMP out of memory during validation

Hi, I’m using AMP in my training loop and validating the model with torch.no_grad periodically, say every 1000 updates. For the first 1000 updates the program works correctly, showing lower GPU memory consumption and faster speed than regular FP32 training. However, as soon as validation starts, memory usage quickly blows up and a CUDA OOM error is thrown. On the other hand, FP32 training with exactly the same hyperparameters and training data does not OOM throughout the whole training run.

@torch.no_grad()
def validate(args, model, dev_itr):
    ...

def train():
    for i, batch in enumerate(train_itr):
        with torch.cuda.amp.autocast(True):
            loss = model(batch)  # batch and model both fp32
            loss.backward()
            optimizer.step()
        if i % val_every == 0:
            with torch.cuda.amp.autocast(True):
                validate(args, model, dev_itr)  # OOM

If I disable the autocast context right outside validate(), training proceeds successfully (a rough sketch of the working version is below).
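
This is roughly what the working loop looks like, using the same placeholder names as above (train_itr, dev_itr, val_every, optimizer):

def train():
    for i, batch in enumerate(train_itr):
        # autocast only wraps the training step
        with torch.cuda.amp.autocast(True):
            loss = model(batch)
            loss.backward()
            optimizer.step()
        if i % val_every == 0:
            validate(args, model, dev_itr)  # no outer autocast -> no OOM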
Any clue about the source of the problem would be helpful.


Could you post an executable code snippet to reproduce this issue?
Also, which GPU, CUDA, and cudnn versions are you using?
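You can print them with something like:

import torch

print(torch.__version__)               # PyTorch version
print(torch.version.cuda)              # CUDA version PyTorch was built with
print(torch.backends.cudnn.version())  # cuDNN version
print(torch.cuda.get_device_name(0))   # GPU model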