"No inf checks were recorded for this optimizer" error with a custom loss

I've written a custom loss function to measure the average circularity of channel segmentations in a batch.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def M(i, j, I):
    '''calculates the (i, j)-th raw moment of image I'''
    x_grid, y_grid = torch.meshgrid(
        torch.arange(0, I.shape[0], dtype=torch.float, device=I.device, requires_grad=True),
        torch.arange(0, I.shape[1], dtype=torch.float, device=I.device, requires_grad=True))
    x_grid = x_grid**i
    y_grid = y_grid**j
    moment = torch.sum(x_grid * y_grid * I)

    return moment
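(For reference, this is the raw image moment M_ij = sum over pixels of x^i * y^j * I(x, y).)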

class MomentCircScore(nn.Module):
    def __init__(self, weight: torch.Tensor = None):
        super(MomentCircScore, self).__init__()

        # Sanity checks
        assert weight is None or isinstance(weight, torch.Tensor), "Class Loss weight must be a tensor"
        assert weight is None or weight.ndim == 1, "Class Loss weight must be a 1D tensor"
        assert weight is None or weight.sum() == 1, "Class Loss weight must sum to 1"

        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        assert weight is None or device == weight.device, "Loss weight on inconsistent device"

        # Declare loss variables
        self.weights = weight # get rid of background weight
        self.pi = torch.tensor(math.pi)
        self.device = device

    def forward(self,y_pred):
        # softmax pred
        y_pred = F.softmax(y_pred, dim=1)
        shape = y_pred.shape
        # init scores tensor B x C
        circ_scores = torch.zeros((shape[0], shape[1]), device=self.device, requires_grad=True, dtype=torch.float)
        # label each pixel based on max channel value
        with torch.no_grad():
            labelled_y_pred = torch.argmax(y_pred,dim=1)
        
        # calc circularity score for each channel
        for batch in range(shape[0]):
            for class_ in range(shape[1]): # loop over every class channel (background included)
                with torch.no_grad():
                    # generate binary class seg 
                    class_map = torch.where(labelled_y_pred[batch]==class_,1,0)
                     
                # calculate moments
                M_0_0 = M(0,0,class_map)
                M_1_0 = M(1,0,class_map)
                M_0_1 = M(0,1,class_map)

                x_bar = M_1_0/M_0_0
                y_bar = M_0_1/M_0_0
                
                # calculate circularity score for class:
                # c_s = M_0_0**2 / (2*pi*(mu_2_0 + mu_0_2)), where mu_2_0 = M_2_0 - x_bar*M_1_0
                # and mu_0_2 = M_0_2 - y_bar*M_0_1 are the central second moments
                c_s = (1/(2*self.pi)) * (M_0_0**2/(M(2,0,class_map) + M(0,2,class_map) - y_bar*M_0_1 - x_bar*M_1_0))
                if torch.isnan(c_s):
                    # if no pixels belong to class, set score to 0
                    c_s = torch.tensor(0,device=self.device,requires_grad=True,dtype=torch.float)
                
                # add score to scores tensor
                with torch.no_grad():
                    circ_scores[batch, class_] = c_s
        
        # weight / average scores along channel
        if self.weights is not None:
            weighted_circ_scores = circ_scores * self.weights.repeat((shape[0], 1))
            average_class_circ_scores = weighted_circ_scores.sum(1)
        
        else:
            average_class_circ_scores = circ_scores.mean(1)
        
        return 1 - average_class_circ_scores.mean()
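
For context, this is how the loss is called; a minimal sketch, with made-up batch/class/spatial sizes and raw logits as input:

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
criterion = MomentCircScore()
logits = torch.randn(2, 3, 64, 64, device=device, requires_grad=True)  # B x C x H x W model output
loss = criterion(logits)
loss.backward()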

Since the loss requires me to convert the model's output to a binary format, I need to use torch.argmax and torch.where, which both break gradients. As such, I put the calls to these functions inside torch.no_grad() blocks.
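
A quick illustration of the problem, independent of my model:

x = torch.randn(4, requires_grad=True)
idx = torch.argmax(x)            # integer index tensor, detached from the graph
mask = torch.where(x > 0, 1, 0)  # integer mask, also detached
print(idx.requires_grad, mask.requires_grad)  # False False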

I was hoping that the PyTorch autograd engine would work its magic, but at runtime I'm getting the assertion error:

Traceback (most recent call last):
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 457, in <module>
    train(args)
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 358, in train
    scaler.step(optimiser)
  File "/home/imagingscience/miniconda3/envs/aorta-seg/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

However, this is raised after the .backward() call in the code, so I'm not sure whether the issue is in the binarisation process or something else.

EDIT-------------------------------------

The error disappears when turning off mixed precision training, i.e. running with the torch.cuda.amp.autocast block disabled and calling optimiser.step() instead of scaler.step(optimiser).

Now I have two questions: why is this the fix, and why does my loss work at all when I've got parts of it running inside torch.no_grad() blocks?

The relevant block from train.py:

                # zero parameter gradients
                optimiser.zero_grad()
                # forward + backward + optimise
                with torch.cuda.amp.autocast(params['mp_training']):
                    outputs = net(inputs) # under autocast, eligible ops run in float16, not float32
                    # calculate loss
                    loss = criterion(outputs) if params['loss_function']['loss'] == 'moment_circ_loss' else criterion(outputs,labels) 
                    loss_value = loss.item()

                    if params['mp_training']:
                        scaler.scale(loss).backward()

                # backpropagation
                if params['mp_training']:
                    scaler.step(optimiser)
                    scaler.update()
                else:
                    loss.backward()
                    optimiser.step()
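
For comparison, the torch.cuda.amp docs keep the backward call outside the autocast block; a minimal sketch of that pattern, using the same names as above:

optimiser.zero_grad()
with torch.cuda.amp.autocast():
    outputs = net(inputs)
    loss = criterion(outputs, labels)
# exit autocast before backward; backward runs on the scaled loss
scaler.scale(loss).backward()
scaler.step(optimiser)  # unscales gradients and skips the step if any contain inf/NaN
scaler.update()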

ANSWER-------------------------------------

I'm not sure if you've fixed the issue or if the model is simply not being updated, since (as you've already explained) argmax breaks the computation graph.
You could check the .grad attribute of all parameters of the model after the backward call and see whether they contain valid gradients.
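
A minimal sketch of that check, reusing the net and loss names from your train.py:

loss.backward()  # or scaler.scale(loss).backward() with AMP
for name, param in net.named_parameters():
    if param.grad is None:
        print(f'{name}: no gradient')  # expected here, since argmax/no_grad cut the graph
    else:
        print(f'{name}: grad norm {param.grad.norm().item():.3e}')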