I am trying to train a DDP model (one GPU per process, but I’ve added the with `autocast(enabled=args.use_mp):`

to model forward just in case) with mixed precision using `torch.cuda.amp`

with train_bert function.

The model trains fine without `amp`

as well as with `autocast(enabled=False)`

. When I try running it with mixed precision (args.use_mp = True), I get nan loss after first iteration.

I used `autograd.detect_anomaly()`

to find that nan occurs in CrossEntropyLoss: RuntimeError: Function ‘LogSoftmaxBackward’ returned nan values in its 0th output. Not sure what kind of mistake am I looking for.

Below I insert the code for the training function and to the LossClass (criterion):

```
def train_bert(rank, args, epoch,
model, optimizer, criterion, loss_components,
train_generator, phase, scaler=None):
"""
bert training on args.bert_steps batches
:param rank: process rank
:param args: config class
:param epoch: epoch number for logging
:param model: bert model
:param optimizer: bert optimizer
:param criterion: callable loss
:param loss_components: torch.Tensor with cumulative losses for each objective
:param train_generator: train data loader (chunk wise)
:param phase: phase for phase wise training
:param scaler: scaler for mixed precision training (used if args.use_apex=True)
:return: mean step loss on the part of the training set
"""
model.train()
optimizer.zero_grad()
nb_tr_steps = phase * args.bert_steps
time_dict = {'step_time': 0.0, 'forward_time': 0.0, 'backward_time': 0.0}
for step, batch in islice(enumerate(tqdm(train_generator)), phase * args.bert_steps, (phase + 1) * args.bert_steps):
inp, lbl, meta = batch
X_ids, X_type, X_attn = (X.cuda(rank) for X in inp)
lbl = lbl.cuda(rank, non_blocking=True)
start = time.perf_counter()
with autocast(enabled=args.use_mp):
pred = model(X_ids, X_type, X_attn)
time_dict['forward_time'] += time.perf_counter() - start
components, loss = criterion(pred, lbl, args)
loss = loss / args.grad_acum_steps
start = time.perf_counter()
if args.use_mp:
scaler.scale(loss).backward()
else:
loss.backward()
time_dict['backward_time'] += time.perf_counter() - start
loss_components += components
nb_tr_steps += 1
if nb_tr_steps % args.grad_acum_steps == 0:
if args.use_mp:
scaler.step(optimizer)
scaler.update()
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step()
optimizer.zero_grad()
return loss_components.sum() / nb_tr_steps
```

```
class ConditionalLoss:
def __init__(self, args, rank):
self.bce = nn.BCEWithLogitsLoss().cuda(rank)
self.ce = nn.CrossEntropyLoss(reduction='none').cuda(rank)
self.loss_weights = args.loss_weights
def _binary_loss(self, pred, lbl):
return self.bce(pred, lbl.unsqueeze(1))
def _start_end_loss(self, pred, lbl, is_yes_no):
is_span = (1 - is_yes_no)
return (is_span * (self.ce(pred[0], lbl[0]) + self.ce(pred[1], lbl[1]))).mean()
def __call__(self, preds, labels, args):
pred_is_yes_no, pred_span = preds
is_yes_no, yes_no, span = unpack_lbls(labels)
s_w, yn_w, iyn_w = self.loss_weights
is_y_n_loss = iyn_w * self._binary_loss(pred_is_yes_no, is_yes_no)
span_loss = s_w * self._start_end_loss(pred_span, span, is_yes_no)
return torch.Tensor([span_loss.item(), is_y_n_loss.item()]), is_y_n_loss + span_loss
```