I've written a custom loss function to measure the average circularity of channel segmentations in a batch:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def M(i, j, I):
    '''Calculates the (i, j)-th moment of image I.'''
    x_grid, y_grid = torch.meshgrid(
        torch.arange(0, I.shape[0], dtype=torch.float, device=I.device, requires_grad=True),
        torch.arange(0, I.shape[1], dtype=torch.float, device=I.device, requires_grad=True),
    )
    x_grid = x_grid ** i
    y_grid = y_grid ** j
    moment = torch.sum(x_grid * y_grid * I)
    return moment


class MomentCircScore(nn.Module):
    def __init__(self, weight: torch.Tensor = None):
        super(MomentCircScore, self).__init__()
        # Sanity checks
        assert weight is None or isinstance(weight, torch.Tensor), "Class Loss weight must be a tensor"
        assert weight is None or weight.ndim == 1, "Class Loss weight must be a 1D list"
        assert weight is None or weight.sum() == 1, "Class Loss weight must sum to 1"
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        assert weight is None or device == weight.device, "Loss weight on inconsistent device"

        # Declare loss variables
        self.weights = weight  # get rid of background weight
        self.pi = torch.tensor(math.pi)
        self.device = device

    def forward(self, y_pred):
        # softmax pred
        y_pred = F.softmax(y_pred, dim=1)
        shape = y_pred.shape  # (B, C, H, W)

        # init scores tensor B x C
        circ_scores = torch.zeros((shape[0], shape[1]), device=self.device, requires_grad=True, dtype=torch.float)

        # label each pixel based on max channel value
        with torch.no_grad():
            labelled_y_pred = torch.argmax(y_pred, dim=1)

        # calc circularity score for each channel
        for batch in range(shape[0]):
            for class_ in range(0, shape[1]):  # not calculating background for computation speed-up
                with torch.no_grad():
                    # generate binary class seg
                    class_map = torch.where(labelled_y_pred[batch] == class_, 1, 0)

                # calculate moments
                M_0_0 = M(0, 0, class_map)
                M_1_0 = M(1, 0, class_map)
                M_0_1 = M(0, 1, class_map)
                x_bar = M_1_0 / M_0_0
                y_bar = M_0_1 / M_0_0

                # calculate circularity score for class
                c_s = (1 / (2 * self.pi)) * (M_0_0 ** 2 / (M(2, 0, class_map) + M(0, 2, class_map) - y_bar * M_0_1 - x_bar * M_1_0))
                if torch.isnan(c_s):
                    # if no pixels belong to class, set score to 0
                    c_s = torch.tensor(0, device=self.device, requires_grad=True, dtype=torch.float)

                # add score to scores tensor
                with torch.no_grad():
                    circ_scores[batch, class_] = c_s

        # weight / average scores along channel
        if self.weights is not None:
            weighted_circ_scores = circ_scores * self.weights.repeat((shape[0], 1))
            average_class_circ_scores = weighted_circ_scores.sum(1)
        else:
            average_class_circ_scores = circ_scores.mean(1)

        return 1 - average_class_circ_scores.mean()
```
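As a quick sanity check (the shapes here are made up), the loss tensor I get back does require grad, but as far as I can tell its graph never reaches the network output:

```python
criterion = MomentCircScore()
logits = torch.randn(2, 4, 64, 64, device=criterion.device, requires_grad=True)

loss = criterion(logits)
print(loss.requires_grad)  # True - circ_scores is created with requires_grad=True

loss.backward()
print(logits.grad)         # None - no gradient ever flows back to the network output
```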
Since the loss requires me to convert the model's output to a binary format, I need to use torch.argmax and torch.where, which both break gradients. As such, I put the calls to these functions inside torch.no_grad() blocks.
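To illustrate what I mean by "break gradients", here is a standalone snippet (not from my training code):

```python
import torch

x = torch.randn(2, 3, 4, 4, requires_grad=True)
with torch.no_grad():
    labels = torch.argmax(x, dim=1)        # integer tensor, no autograd history
    mask = torch.where(labels == 1, 1, 0)  # likewise detached from x

print(labels.requires_grad, mask.requires_grad)  # False False
```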
I was hoping that the PyTorch autograd engine would work its magic, but at runtime I get the following assertion error:
```
Traceback (most recent call last):
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 457, in <module>
    train(args)
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 358, in train
    scaler.step(optimiser)
  File "/home/imagingscience/miniconda3/envs/aorta-seg/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
```
However, this error is raised after the .backward() call in my code, so I am not sure whether the issue lies in the binarisation process or somewhere else.
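If it helps narrow things down, here is a minimal sketch (the model and shapes are made up) of what I suspect is happening: a loss that requires grad but whose graph is disconnected from the model parameters, stepped through GradScaler on a CUDA machine:

```python
import torch

scaler = torch.cuda.amp.GradScaler()
model = torch.nn.Linear(4, 1).cuda()
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

with torch.cuda.amp.autocast():
    out = model(torch.randn(2, 4, device='cuda'))

# a loss that requires grad but whose graph never touches the model
loss = torch.zeros(1, device='cuda', requires_grad=True).sum()
scaler.scale(loss).backward()  # runs, but no model parameter gets a .grad
scaler.step(optimiser)         # AssertionError: No inf checks were recorded for this optimizer.
```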
The error disappears when I turn off mixed-precision training, i.e. when the optimiser step runs as a plain optimiser.step() outside the with torch.cuda.amp.autocast block instead of through scaler.step(optimiser).
Now I have two questions: why is this the fix, and why does my loss work at all when I've got parts of it running inside torch.no_grad() blocks?
The relevant block from my training loop:
```python
# zero parameter gradients
optimiser.zero_grad()

# forward + backward + optimise
with torch.cuda.amp.autocast(params['mp_training']):
    outputs = net(inputs)  # convert datatype to float32 for training
    # calculate loss
    loss = criterion(outputs) if params['loss_function']['loss'] == 'moment_circ_loss' else criterion(outputs, labels)
    loss_value = loss.item()
    if params['mp_training']:
        scaler.scale(loss).backward()  # backpropagation
    if params['mp_training']:
        scaler.step(optimiser)
        scaler.update()
    else:
        loss.backward()
        optimiser.step()
```
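For comparison, my understanding of the recommended AMP recipe (adapted from the PyTorch AMP examples; the loader is a placeholder) keeps the backward pass and the optimiser step outside autocast:

```python
scaler = torch.cuda.amp.GradScaler()

for inputs, labels in loader:
    optimiser.zero_grad()

    # only the forward pass and the loss computation run under autocast
    with torch.cuda.amp.autocast():
        outputs = net(inputs)
        loss = criterion(outputs, labels)

    # backward and the optimiser step run outside the autocast context
    scaler.scale(loss).backward()
    scaler.step(optimiser)
    scaler.update()
```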