U-Net training with mixed precision fails with cublasGemmEx RuntimeError

We are trying to train a U-Net based network for box classification using cross entropy loss. We are using the new NVIDIA APEX mixed precision (amp) library and are running into the following issue during training. After a couple of epochs (how many depends on the APEX amp optimization level), training stops with the following error message:

...
  File "/home/pavel/.local/lib/python3.7/site-packages/torch/tensor.py", line 150, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/pavel/.local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

Could there be a bug in backward that is triggered by this particular network? A resnet18 trains fine with apex at any optimization level.
Does anyone have a clue what the reason could be?
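
For reference, the model and optimizer go through amp.initialize before the training loop starts. A minimal sketch of that setup (the Linear model, Adam optimizer, and opt_level below are placeholders standing in for our U-Net and its actual configuration):

    import torch
    from apex import amp

    # Placeholder model standing in for the U-Net; the real model is much larger.
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # opt_level is one of "O0", "O1", "O2", "O3"; the crash shows up after a
    # different number of epochs depending on which level is used.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")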

Here is the part of the training loop where the error is thrown:

        for epoch in range(1, self.num_epochs + 1):
            logger.info(f"running epoch {epoch}")
            epoch_start_time = timeit.default_timer()  # needed for the epoch timing below
            avg_train_loss = 0

            self.model.train()
            for step, sample_batch in enumerate(self.train_data, start=1):
                sample_batch = self._sample_to_device(sample_batch)
                self.optimizer.zero_grad()

                doc_id_batch = sample_batch[DOC_ID]

                logits_dict = self.model(sample_batch)
                loss = self.criterion(logits_dict, sample_batch)
                # scale the loss so FP16 gradients don't underflow, then backprop
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()  # <<< exception is thrown

                self.optimizer.step()

                avg_train_loss += loss.item()

            epoch_end_time = timeit.default_timer()
            epoch_time = epoch_end_time - epoch_start_time

The original issue discussion and logs are here: https://github.com/NVIDIA/apex/issues/580#issuecomment-549487887

I’ve commented on the GitHub issue.
Your error might be related to this issue if your input size is too large.
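
To check whether the raw FP16 GEMM itself is the problem (rather than anything specific to the network or to apex), a standalone half-precision matmul with shapes close to the failing layer should exercise the same cublasGemmEx call. A minimal sketch, with placeholder shapes you would replace by the sizes from your model:

    import torch

    # Placeholder shapes; substitute the (m, k) x (k, n) sizes of the layer
    # that fails in backward.
    m, k, n = 4096, 4096, 4096

    a = torch.randn(m, k, device="cuda", dtype=torch.half)
    b = torch.randn(k, n, device="cuda", dtype=torch.half)

    # A half-precision GEMM on a recent GPU should dispatch to the same
    # cublasGemmEx tensor-op path; if the size is the trigger, this may raise
    # the same CUBLAS_STATUS_EXECUTION_FAILED.
    c = torch.matmul(a, b)
    torch.cuda.synchronize()
    print(c.shape)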