PyTorch autograd and backpropagation

Okay, so I am new to PyTorch, and the concept of autograd is still new to me. I am trying to implement a loss function that also considers the distance between two boundaries, which is computed using ray casting. For the first 20 epochs I train with a Dice loss and then switch to the geometric distance loss. After the switch, the distance loss fails to update: the gradients become None and the training loss stays at a constant value, 4.3423. As long as the Dice loss is being computed, the gradients are not None. Are my gradients not flowing backward? Any help will be appreciated.
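I checked the gradients with a loop like the following after loss.backward() (a rough sketch of my check, not the exact code):

# Print which parameters received no gradient after backward()
for name, param in model_distance.named_parameters():
    if param.grad is None:
        print(name, "grad is None")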

import torch
import torch.nn as nn
import kornia
import kornia.morphology as morph

class VectorizedRayCaster(nn.Module):
    def __init__(self, origins, angles, step_size=1.0):
        """
        origins: tensor of shape (N, 2) for N rays (each origin is [x, y]).
        angles: tensor of shape (N,) with angles in degrees for each ray.
        step_size: scalar step to move along the ray.
        """
        super().__init__()
        self.origins = origins  # expected shape: (N, 2)
        self.angles = angles    # expected shape: (N,)
        self.step_size = step_size

    def cast_ray(self, mask, mask_value_threshold=0.5, max_steps=75):
        """
        Vectorized ray casting.
        mask: a 2D tensor of shape (H, W) representing the boundary mask.
        Returns: intersections of shape (N, 2) for N rays.
        """
        # Ensure origins and directions live on the same device as the mask
        device = mask.device
        origins = self.origins.to(device)
        angles = self.angles.to(device)

        N = origins.shape[0]
        # Convert angles to radians and compute ray directions (N, 2)
        theta_rad = torch.deg2rad(angles)  # (N,)
        directions = torch.stack([torch.cos(theta_rad), torch.sin(theta_rad)], dim=-1)  # (N, 2)

        # Initialize current positions with the origins (N, 2)
        current_pos = origins.clone()

        # Prepare an output tensor for intersections, filled with NaNs initially
        intersections = torch.full_like(current_pos, float('nan'))
        hit_mask = torch.zeros(N, dtype=torch.bool, device=device)

        H, W = mask.shape  # mask is assumed to be 2D

        # Iterate for a maximum number of steps
        for step in range(max_steps):
            # Advance all rays by one step (including those that already hit;
            # their recorded intersections are no longer updated)
            current_pos = current_pos + self.step_size * directions

            # Convert current positions to integer pixel indices
            x_idx = current_pos[:, 0].long()  # (N,)
            y_idx = current_pos[:, 1].long()  # (N,)

            # Keep only rays that are still inside the image
            valid = (x_idx >= 0) & (x_idx < W) & (y_idx >= 0) & (y_idx < H)
            if valid.sum() == 0:
                # No rays left inside the image
                break

            # For valid rays, sample the mask values
            mask_values = mask[y_idx[valid], x_idx[valid]]

            # Determine which valid rays crossed the threshold
            hits = mask_values >= mask_value_threshold

            # Global indices of the valid rays; view(-1) ensures a 1D vector
            valid_indices = valid.nonzero(as_tuple=False).view(-1)

            # For rays that hit, record the intersection if not already recorded
            for idx, hit_val in zip(valid_indices.tolist(), hits):
                if hit_val and not hit_mask[idx]:
                    intersections[idx] = current_pos[idx]
                    hit_mask[idx] = True

            # If all rays have hit, exit early
            if hit_mask.all():
                break

        # Rays that never hit get a default value of [0.0, 0.0]
        intersections[~hit_mask] = 0.0

        return intersections
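For reference, this is a minimal sanity check of the ray caster on a dummy ring-shaped mask (shapes and values are illustrative, not my real data):

# Illustrative only: cast 12 rays from the image center against a ring of radius ~60
H, W = 256, 256
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
r = torch.sqrt((xx - 128.0) ** 2 + (yy - 128.0) ** 2)
ring = ((r > 58) & (r < 62)).float()  # thin circular boundary

angles = torch.arange(0, 360, 30)                  # (12,) angles in degrees
origins = torch.full((angles.shape[0], 2), 128.0)  # all rays start at the center

caster = VectorizedRayCaster(origins, angles, step_size=1.0)
pts = caster.cast_ray(ring, mask_value_threshold=0.5, max_steps=100)
print(pts)  # (12, 2) intersection points, roughly 60 px from the center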

def find_centeroid(epi, endo, threshold=0.5, default=[128.0, 128.0]):
    B, C, H, W = epi.shape
    centroids = []

    for b in range(B):
        center_epi = torch.argwhere(epi[b][0] > threshold).float()
        center_epi = torch.mean(center_epi, dim=0)

        center_endo = torch.argwhere(endo[b][0] > threshold).float()
        center_endo = torch.mean(center_endo, dim=0)

        # If one mask is empty (mean over zero pixels is NaN), fall back to the other
        if torch.isnan(center_epi).any() and not torch.isnan(center_endo).any():
            center_epi = center_endo

        if not torch.isnan(center_epi).any() and torch.isnan(center_endo).any():
            center_endo = center_epi

        # If both masks are empty, fall back to the default center
        if torch.isnan(center_epi).any() and torch.isnan(center_endo).any():
            center_epi = torch.tensor(default, dtype=torch.float, device=epi.device)
            center_endo = torch.tensor(default, dtype=torch.float, device=epi.device)

        center = (center_epi + center_endo) / 2.0

        centroids.append(center.flip(0))  # flip (row, col) to (x, y)

    return centroids
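A tiny illustration of the centroid helper on dummy 4x4 masks (illustrative only):

epi = torch.zeros(1, 1, 4, 4)
epi[0, 0, 1, 2] = 1.0   # single pixel: centroid at (row=1, col=2)
endo = torch.zeros(1, 1, 4, 4)
endo[0, 0, 3, 0] = 1.0  # single pixel: centroid at (row=3, col=0)

# Midpoint is (row=2, col=1), flipped to (x=1, y=2)
print(find_centeroid(epi, endo))  # [tensor([1., 2.])]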


def get_boundary(mask, method='erosion'):
    mask = mask.float()

    if method == 'erosion':
        # 2D kernel for the morphological operations
        kernel = torch.ones(3, 3, device=mask.device, dtype=mask.dtype)

        # Erosion via Kornia (a Canny edge filter could also be used,
        # but it leaves holes in the boundaries)
        eroded = morph.erosion(mask, kernel)

        # Kornia's morphology is differentiable
        boundary = mask - eroded

        # Fill gaps in the boundary
        closed_boundary = morph.closing(boundary, kernel)

    elif method == 'canny':
        # kornia.filters.canny returns (magnitude, edges); keep the edge map
        _, closed_boundary = kornia.filters.canny(mask, low_threshold=0.1,
                                                  high_threshold=0.5, kernel_size=(5, 5),
                                                  sigma=(1, 1))

    else:
        raise ValueError(f"Unknown method: {method}")

    return closed_boundary
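Since my question is about gradient flow, this quick check (a sketch) confirms that the erosion-based boundary keeps the autograd graph alive:

m = torch.rand(1, 1, 32, 32, requires_grad=True)
b = get_boundary(m, method='erosion')
b.sum().backward()
print(m.grad is not None)  # True: Kornia's morphology ops are differentiable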

from torch.optim import Adam
from tqdm import tqdm

from adaptive_dice_loss import AdaptiveDiceLoss
import unet_zhao


model_distance = unet_zhao.UNet(2)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dice_loss = AdaptiveDiceLoss()

# Move model to device
model_distance = model_distance.to(device)

# Adam optimizer and L1 criterion for the distance loss
optimizer = Adam(model_distance.parameters(), lr=0.001)
criterion = nn.L1Loss(reduction='mean')

# Number of epochs to train
num_epochs = 25

# Lists to store the losses for plotting or monitoring
train_losses, val_losses = [], []
angles = torch.arange(0, 360, 30, device=device)

for epoch in range(num_epochs):
    # Training phase
    model_distance.train()  # Set the model to training mode
    train_loss = 0.0

    for batch_idx, (img, label, gt_dist) in enumerate(tqdm(train_loader_dist, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
        # Move data to device
        img, label = img.to(device, dtype=torch.float), label.to(device, dtype=torch.float)
        gt_dist = gt_dist.to(device, dtype=torch.float)

        epi_mask, endo_mask = label[:, 0, :, :], label[:, 1, :, :]

        optimizer.zero_grad()  # Zero gradients

        # Forward pass
        pred = model_distance(img)

        epi_mask_pred, endo_mask_pred = pred[:, 0, :, :], pred[:, 1, :, :]

        # Loss function: Dice for the first 20 epochs, geometric distance loss afterwards
        dice_ = dice_loss(pred, label, current_epoch=epoch, total_epochs=20)

        if epoch <= 20:
            loss = dice_

        else:
            # Convert the ground-truth vectors to norms to get the distances
            gt_dist = torch.linalg.norm(gt_dist, dim=2)

            epi_boundary = get_boundary(epi_mask_pred.unsqueeze(1))
            endo_boundary = get_boundary(endo_mask_pred.unsqueeze(1))
            centroid = find_centeroid(epi_boundary, endo_boundary)

            diff_norms = []

            for b in range(epi_boundary.shape[0]):
                ray_origins = centroid[b].unsqueeze(0).expand(angles.shape[0], -1)  # Shape: (N, 2)

                # Create the ray casters with consistent dimensions
                epi_ray = VectorizedRayCaster(ray_origins, angles, 1.0)
                intersection_epi = epi_ray.cast_ray(epi_boundary[b][0], mask_value_threshold=0.5, max_steps=75)

                endo_ray = VectorizedRayCaster(ray_origins, angles, 1.0)
                intersection_endo = endo_ray.cast_ray(endo_boundary[b][0], mask_value_threshold=0.5, max_steps=75)

                diff = intersection_epi - intersection_endo
                norm = torch.linalg.norm(diff, dim=1)
                diff_norms.append(norm)

            pred_dist = torch.stack(diff_norms).requires_grad_()
            print("Gt and pred shapes", gt_dist.shape, pred_dist.shape)

            loss = criterion(pred_dist, gt_dist)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate training loss
        train_loss += loss.item()

    # Average training loss for the epoch
    avg_train_loss = train_loss / len(train_loader_dist)

    train_losses.append(avg_train_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

Some generic debugging advice:

  • checking whether the tensors in your forward pass have requires_grad=True, to identify where the autograd graph might be broken. A common way to break the graph is to round-trip through NumPy or to create fresh tensors that aren't tracked by autograd (see the sketch below).
  • using hooks (see "Autograd mechanics" in the PyTorch documentation) to inspect the values of the gradients as they flow backward.
  • using TORCH_LOGS="+autograd" to observe the flow of backward execution and check where it stops.
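A minimal sketch of the first two points: a tensor hook fires as gradients flow backward, and a detached (or freshly created) tensor stops the flow:

import torch

x = torch.randn(4, requires_grad=True)
y = x * 2
y.register_hook(lambda g: print("grad wrt y:", g))  # fires during backward

z = y.detach()       # detach() breaks the graph, much like .long() or new tensors
loss = (z * y).sum()
loss.backward()      # the hook fires and x.grad is populated...
print(x.grad)
print(z.grad)        # ...but z is outside the graph, so its grad is None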