Okay, so I am new to PyTorch, and the concept of autograd is still new to me. I am trying to implement a loss function that also considers the distance between two boundaries, which is computed using ray casting. However, after the dice loss is done, the distance loss fails to update. The gradients become none, and the training loss is a constant number, i.e 4.3423. Are my gradients not flowing backward? Any help will be appreciated. After the switch from 20 to the geometric loss, the grads are none. I checked it using. As long as the dice is computed, the grads are not None.
import torch
import torch.nn as nn
import kornia.morphology as morph
class VectorizedRayCaster(nn.Module):
def __init__(self, origins, angles, step_size=1.0):
"""
origins: Tensor of shape (N, 2) for N rays (each origin is [x, y])
angles: Tensor of shape (N,) with angles in degrees for each ray.
step_size: Scalar step to move along the ray.
"""
super().__init__()
self.origins = origins # Expected shape: (N, 2)
self.angles = angles # Expected shape: (N,)
self.step_size = step_size
def cast_ray(self, mask, mask_value_threshold=0.5, max_steps=75, device='cuda'):
"""
Vectorized ray casting.
mask: a 2D tensor with shape (H, W) representing the boundary mask.
Returns: intersections of shape (N, 2) for N rays.
"""
# Ensure both directions and origins are on the same device as mask
device = mask.device # Get the device of the mask
origins = self.origins.to(device)
angles = self.angles.to(device)
N = self.origins.shape[0]
# Convert angles to radians and compute ray directions (N, 2)
theta_rad = torch.deg2rad(self.angles) # (N,)
directions = torch.stack([torch.cos(theta_rad), torch.sin(theta_rad)], dim=-1).to(device) # (N, 2)
# Initialize current positions with the origins (N, 2)
current_pos = self.origins.clone().to(device)
# Prepare an output tensor for intersections, fill with NaNs initially
intersections = torch.full_like(current_pos, float('nan'))
hit_mask = torch.zeros(N, dtype=torch.bool, device=mask.device)
H, W = mask.shape # mask is assumed to be 2D
# Iterate for a maximum number of steps
for step in range(max_steps):
# Update positions of all rays (only those that haven't hit yet)
current_pos = current_pos + self.step_size * directions
# Convert current positions to integer indices (using long())
x_idx = current_pos[:, 0].long() # (N,)
y_idx = current_pos[:, 1].long() # (N,)
# Check boundaries for all rays
valid = (x_idx >= 0) & (x_idx < W) & (y_idx >= 0) & (y_idx < H)
if valid.sum() == 0:
# If no rays are valid, break out of the loop.
break
# For valid rays, index the mask to get mask values
mask_values = mask[y_idx[valid], x_idx[valid]]
# Determine which valid rays hit the threshold
hits = mask_values >= mask_value_threshold
# Get global indices of valid rays. Using view(-1) ensures a 1D vector.
valid_indices = valid.nonzero(as_tuple=False).view(-1)
# For rays that hit, update intersections if not already updated
for idx, hit_val in zip(valid_indices, hits):
# Ensure idx is an integer
idx = idx.item() if isinstance(idx, torch.Tensor) else idx
if hit_val and not hit_mask[idx]:
intersections[idx] = current_pos[idx]
hit_mask[idx] = True
# If all rays have hit, we can exit early
if hit_mask.all():
break
# For rays that never hit, assign a default value (e.g., [0.0, 0.0])
intersections[~hit_mask] = 0.0
return intersections
enter code here
def find_centeroid(epi, endo, threshold=0.5, default=[128.0, 128.0]):
B, C, H, W = epi.shape
centroids = []
for b in range(B):
center_epi = torch.argwhere(epi[b][0] > threshold).float()
center_epi = torch.mean(center_epi, axis=0)
center_endo = torch.argwhere(endo[b][0] > threshold).float()
center_endo = torch.mean(center_endo, axis=0)
if torch.isnan(center_epi).any() and not torch.isnan(center_endo).any():
center_epi = torch.mean(center_endo, axis=0)
if not torch.isnan(center_epi).any() and torch.isnan(center_endo).any():
center_endo = torch.mean(center_epi, axis=0)
if torch.isnan(center_epi).any() and torch.isnan(center_endo).any():
center_epi, center_endo = torch.tensor(default, dtype=torch.float), torch.tensor(default, dtype=torch.float)
center = (center_epi + center_endo) / 2.0
centroids.append(center.flip(0)) # flip x,y instead of y,x
return centroids
def get_boundary(mask, method='erosion'):
mask = mask.float()
if method == 'erosion':
# Define a 2D kernel for morphological operation.
kernel = torch.ones(3, 3, device=mask.device, dtype=mask.dtype)
# Perform erosion using Kornia. Canny edge filter can also be used. The boundaries have holes
eroded = morph.erosion(mask, kernel)
# Kornia is differentiable
boundary = mask - eroded
# fill boundary gaps
closed_boundary = morph.closing(boundary, kernel)
elif method == 'canny':
closed_boundary = kornia.filters.canny(mask, low_threshold=0.1,
high_threshold=0.5, kernel_size=(5, 5),
sigma=(1, 1))
return closed_boundary
enter code here
from adaptive_dice_loss import AdaptiveDiceLoss
model_distance = unet_zhao.UNet(2)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dice_loss = AdaptiveDiceLoss()
# Move model to device
model_distance = model_distance.to(device)
# Adam optimizer MSE
optimizer = Adam(model_distance.parameters(), lr=0.001)
criterion = nn.L1Loss(reduction='mean')
# Number of epochs to train
num_epochs = 25
# Lists to store the losses for plotting or monitoring
train_losses, val_losses = [], []
angles = torch.arange(0, 360, 30, device=device)
for epoch in range(num_epochs):
# Training phase
model_distance.train() # Set the model to training mode
train_loss = 0.0
for batch_idx, (img, label, gt_dist) in enumerate(tqdm(train_loader_dist, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
img, label = img.to(device, dtype=torch.float), label.to(device, dtype=torch.float) # Move data to device
gt_dist = gt_dist.to(device, dtype=torch.float)
# Convert vectors to norm to find the distances
epi_mask, endo_mask, gt_dist = label[:, 0, :, :], label[:, 1, :, :], gt_dist
optimizer.zero_grad() # Zero gradients
# Forward pass
pred = model_distance(img)
epi_mask_pred, endo_mask_pred = pred[:, 0, :, :], pred[:, 1, :, :]
# loss function
dice_ = dice_loss(pred, label, current_epoch=epoch, total_epochs=20)
if epoch <= 20:
loss = dice_
else:
gt_dist = torch.linalg.norm(gt_dist, dim=2)
epi_boundary, endo_boundary = get_boundary(epi_mask_pred.unsqueeze(1)), get_boundary(endo_mask_pred.unsqueeze(1))
centroid = find_centeroid(epi_boundary, endo_boundary)
diff_norms = []
for b in range(0, epi_boundary.shape[0]):
ray_origins = centroid[b].unsqueeze(0).expand(angles.shape[0], -1) # Shape: (N, 2)
# Now create the ray caster with consistent dimensions
epi_ray = VectorizedRayCaster(ray_origins, angles, 1.0)
intersection_epi = epi_ray.cast_ray(epi_boundary[b][0], mask_value_threshold=0.5, max_steps=75).to(device)
endo_ray = VectorizedRayCaster(ray_origins, angles, 1.0)
intersection_endo = endo_ray.cast_ray(endo_boundary[b][0], mask_value_threshold=0.5, max_steps=75).to(device)
diff = (intersection_epi - intersection_endo)
norm = torch.linalg.norm(diff, dim=1)
diff_norms.append(norm)
pred_dist = torch.stack(diff_norms).requires_grad_()
print("Gt and pred shapes", gt_dist.shape, pred_dist.shape)
gt_dist = gt_dist.to(device)
pred_dist = pred_dist.to(device)
loss = criterion(gt_dist, pred_dist)
# Backward pass and optimization
loss.backward()
optimizer.step()
# Accumulate training loss
train_loss += loss.item()
# Average training loss for the epoch
avg_train_loss = train_loss / len(train_loader)
train_losses.append(avg_train_loss)
print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")