Assert grad_scale is None and found_inf is None

Hi all,
I'm hitting an AssertionError from the assertion in the title when calling GradScaler.step(opt). I've been using this same model without mixed precision for a while with no issues, but now that I'm adding mixed precision to speed things up, I'm running into this error. I can find almost no information about it on Google, hence this post. I'm happy to provide whatever context is needed, but since I don't know up front what you'll need to diagnose this, I'll wait to be asked rather than guess at what's relevant and dump a lot of unnecessary context here. Thanks in advance; this forum has been a huge help to the project I'm working on so far (:

Based on the GradScaler usage you posted, it looks like you aren't creating an actual object and are instead calling step on the class directly. Could you post the training loop you are using?
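For comparison, here is a minimal sketch of the usual GradScaler pattern: create one GradScaler instance (named scaler here, so it does not shadow the class) and call scale, step, and update on that instance. The toy model, data, and learning rate are placeholders, and AMP is only enabled when CUDA is available so the sketch also runs on CPU, where autocast and the scaler become no-ops.

```python
import torch

torch.manual_seed(0)

# AMP (float16 autocast + loss scaling) is only switched on when a GPU is present;
# on CPU both autocast and the scaler are disabled so the loop still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16  # dtype is ignored when disabled

# Placeholder model and optimizer.
model = torch.nn.Linear(8, 1).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# One GradScaler *instance*, named so it does not shadow the GradScaler class.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Placeholder data.
x = torch.randn(32, 8, device=device)
y = torch.randn(32, 1, device=device)

losses = []
for _ in range(5):
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # backprop through the (possibly) scaled loss
    scaler.step(opt)               # unscales grads; skips opt.step() if infs/NaNs were found
    scaler.update()                # adjusts the scale factor for the next iteration
    losses.append(loss.item())

print(losses)
```

The key point is that scaler.step(opt) is called on the object returned by GradScaler(...), never on the class itself.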

Nice timing, I was just about to post the solution I found. Apologies for the unclear syntax, though: GradScaler in my post above is indeed an object; I just happened to (stupidly) name the variable the same as the class itself ((((((((((((((((((((:
Anyway, a quick look at Adam's source (/home/USERNAME/miniconda3/envs/ENVNAME/lib/python3.8/site-packages/torch/optim/adam.py, which is where the assertion error came from) revealed the following, starting on line 196 inside the step function:

if group['fused'] and grad_scaler is not None:
    grad_scale = grad_scaler._get_scale_async()
    device = grad_scale.device
    grad_scale = _MultiDeviceReplicator(grad_scale)
    found_inf = _get_fp16AMP_params(optimizer=self, grad_scaler=grad_scaler, device=device)

Noticing the if group['fused'] check, I removed fused=True from my optimizer's constructor arguments, and now everything works fine. I only discovered the foreach and fused arguments a couple of hours ago; I understand what foreach does, but I still have no real idea what fused does. Setting both to True gave me roughly a 25-33% speedup in training, but removing fused (which presumably sets it back to False) does not seem to have slowed training down at all, so all of that speedup must have come from foreach=True. And the assertion error is gone.
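For reference, a minimal sketch of an Adam optimizer that enables only the multi-tensor foreach path and leaves fused at its default; the Linear model and tensor shapes are placeholders. As far as I know, newer PyTorch releases reject foreach=True combined with fused=True with a RuntimeError, so picking just one of the two seems the safer configuration.

```python
import torch

# Placeholder model; moved to the GPU only if one is available.
model = torch.nn.Linear(16, 4)
if torch.cuda.is_available():
    model = model.cuda()
device = next(model.parameters()).device

# foreach=True: update all parameters of a group with multi-tensor
# torch._foreach_* ops instead of a per-parameter Python loop.
# fused is intentionally left at its default (not set to True).
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

# One ordinary optimization step to show the configuration works.
loss = model(torch.randn(8, 16, device=device)).pow(2).mean()
loss.backward()
opt.step()

group = opt.param_groups[0]
print(group["foreach"], group.get("fused"))
```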

Like I said, since I have no real idea what fused=True does, I have no insight into why it would cause the above assertion to fail. Perhaps I just missed something in the docs about how this argument interacts with AMP. Do you have any idea why this might be the case?

No, I don't know why the fused path of the optimizer would cause this issue.
Would it be possible to post a minimal, executable code snippet that reproduces the issue?

Uhhh… I would love to, but I suddenly can't reproduce the issue myself. I'm sitting here looking at my code, and I swear it's exactly the same as what I had yesterday, but clearly something has changed. I'll update this thread if the issue reappears, but I guess this one's “case closed” for the time being.
Don’t code under the influence of sleep deprivation, kids.