I've written a custom loss function that measures the average circularity of the channel segmentations in a batch.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def M(i, j, I):
    '''Calculates the (i,j)th moment of image I.'''
    x_grid, y_grid = torch.meshgrid(
        torch.arange(0, I.shape[0], dtype=torch.float, device=I.device, requires_grad=True),
        torch.arange(0, I.shape[1], dtype=torch.float, device=I.device, requires_grad=True))
    x_grid = x_grid ** i
    y_grid = y_grid ** j
    moment = torch.sum(x_grid * y_grid * I)
    return moment
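For context, M computes the raw image moment

$$M_{ij} = \sum_{x}\sum_{y} x^{i}\, y^{j}\, I(x, y)$$

of the (binary) class map I.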
class MomentCircScore(nn.Module):
    def __init__(self, weight: torch.Tensor = None):
        super(MomentCircScore, self).__init__()
        # Sanity checks
        assert weight is None or isinstance(weight, torch.Tensor), "Class loss weight must be a tensor"
        assert weight is None or weight.ndim == 1, "Class loss weight must be a 1D tensor"
        assert weight is None or weight.sum() == 1, "Class loss weight must sum to 1"
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        assert weight is None or device == weight.device, "Loss weight on inconsistent device"
        # Declare loss variables
        self.weights = weight  # get rid of background weight
        self.pi = torch.tensor(math.pi)
        self.device = device

    def forward(self, y_pred):
        # Softmax the prediction
        y_pred = F.softmax(y_pred, dim=1)
        shape = y_pred.shape
        # Initialise the B x C scores tensor
        circ_scores = torch.zeros((shape[0], shape[1]), device=self.device, requires_grad=True, dtype=torch.float)
        # Label each pixel based on its max channel value
        with torch.no_grad():
            labelled_y_pred = torch.argmax(y_pred, dim=1)
        # Calculate the circularity score for each channel
        for batch in range(shape[0]):
            for class_ in range(0, shape[1]):  # not calculating background for computation speed-up
                with torch.no_grad():
                    # Generate the binary class segmentation
                    class_map = torch.where(labelled_y_pred[batch] == class_, 1, 0)
                # Calculate moments
                M_0_0 = M(0, 0, class_map)
                M_1_0 = M(1, 0, class_map)
                M_0_1 = M(0, 1, class_map)
                x_bar = M_1_0 / M_0_0
                y_bar = M_0_1 / M_0_0
                # Calculate the circularity score for the class
                c_s = (1 / (2 * self.pi)) * (M_0_0 ** 2 / (M(2, 0, class_map) + M(0, 2, class_map) - y_bar * M_0_1 - x_bar * M_1_0))
                if torch.isnan(c_s):
                    # If no pixels belong to the class, set the score to 0
                    c_s = torch.tensor(0, device=self.device, requires_grad=True, dtype=torch.float)
                # Add the score to the scores tensor
                with torch.no_grad():
                    circ_scores[batch, class_] = c_s
        # Weight / average the scores along the channel dimension
        if self.weights is not None:
            weighted_circ_scores = circ_scores * self.weights.repeat((shape[0], 1))
            average_class_circ_scores = weighted_circ_scores.sum(1)
        else:
            average_class_circ_scores = circ_scores.mean(1)
        return 1 - average_class_circ_scores.mean()
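The per-class score computed in forward is therefore

$$c_s = \frac{1}{2\pi}\cdot\frac{M_{00}^{2}}{M_{20} + M_{02} - \bar{y}\,M_{01} - \bar{x}\,M_{10}}, \qquad \bar{x} = \frac{M_{10}}{M_{00}},\quad \bar{y} = \frac{M_{01}}{M_{00}},$$

which works out to 1 for an ideal disc. A quick smoke test of the module on random logits (shapes here are arbitrary):

criterion = MomentCircScore()
logits = torch.randn(2, 4, 64, 64, device=criterion.device)  # B x C x H x W
loss = criterion(logits)
print(loss)  # 1 - mean circularity score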
Since the loss requires me to convert the model's output to a binary format, I need to use torch.argmax and torch.where, both of which break gradients. I therefore put the calls to these functions inside torch.no_grad() blocks.
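A minimal sketch of what I mean by "break gradients" (toy tensor, not my actual model output):

y_pred = torch.randn(1, 3, 8, 8, requires_grad=True)  # B x C x H x W logits
labels = torch.argmax(y_pred, dim=1)        # integer tensor, detached from the graph
class_map = torch.where(labels == 0, 1, 0)  # also carries no grad_fn
print(labels.requires_grad, class_map.requires_grad)  # False False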
I was hoping the PyTorch autograd engine would work its magic, but at runtime I get the following assertion error:
Traceback (most recent call last):
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 457, in <module>
    train(args)
  File "/home/imagingscience/Desktop/ia-vj/new_python_scripts/pytorch_scripts/UNet/train.py", line 358, in train
    scaler.step(optimiser)
  File "/home/imagingscience/miniconda3/envs/aorta-seg/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
This happens after the .backward() call in the code, though, so I am not sure whether the issue lies in the binarisation process or somewhere else.
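In case it's relevant, a hypothetical check (reusing the net, criterion and inputs names from my training script) to see whether backward() populates any parameter gradients at all:

outputs = net(inputs)
loss = criterion(outputs)
loss.backward()
print(loss.grad_fn)  # None would mean the loss is cut off from the graph entirely
print(any(p.grad is not None for p in net.parameters()))  # False -> no grads reached the model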
EDIT-------------------------------------
The error disappears when mixed-precision training is turned off, i.e. when the scaler.step(optimiser) call outside the with torch.cuda.amp.autocast block is skipped.
Now I have two questions: why is this the fix, and why does my loss work at all when parts of it run inside torch.no_grad() blocks?
The relevant block from train.py:
# zero parameter gradients
optimiser.zero_grad()
# forward + backward + optimise
with torch.cuda.amp.autocast(params['mp_training']):
    outputs = net(inputs)  # convert datatype to float32 for training
    # calculate loss
    loss = criterion(outputs) if params['loss_function']['loss'] == 'moment_circ_loss' else criterion(outputs, labels)
    loss_value = loss.item()
    if params['mp_training']:
        scaler.scale(loss).backward()
# backpropagation
if params['mp_training']:
    scaler.step(optimiser)
    scaler.update()
else:
    loss.backward()
    optimiser.step()
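For comparison, the canonical AMP recipe from the PyTorch docs keeps backward() and the scaler calls outside the autocast context, so that only the forward pass and the loss computation run under autocast:

optimiser.zero_grad()
with torch.cuda.amp.autocast():
    outputs = net(inputs)
    loss = criterion(outputs)
scaler.scale(loss).backward()  # backward outside autocast
scaler.step(optimiser)         # internally skipped if inf/nan grads are found
scaler.update()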