Removing Return Statement in Module Forward Causes 30+ms Backward Slowdown - Why?

Hello community members,

We recently encountered a puzzling performance issue during training and would appreciate your insights. The core problem: removing a return statement in our module's forward() method slows backward propagation by more than 30 ms, despite the change being logically redundant. Here's the context:

def forward(self, batch_extrinsic):  
    if self._cache_bev_to_camera_projection is None:  # Cache-initialization branch  
        with torch.no_grad():  
            # ~20 tensor operations for projection/mask generation  
            bev_to_camera_projection, bev_mask = self.project_lidar_to_camera(...)  
            # Cache results  
            self._cache_bev_to_camera_projection = bev_to_camera_projection[0:1]  
            self._cache_bev_mask = bev_mask[0:1]  
        return bev_to_camera_projection, bev_mask  # Removing this line slows backward  

    # Cached branch  
    return self._cache_bev_to_camera_projection.repeat(...), self._cache_bev_mask.repeat(...)
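
For reference, here is roughly how we measure the backward time (a minimal sketch assuming a CUDA device; time_backward_ms and step_fn are hypothetical names, where step_fn runs one forward pass and returns a scalar loss):

import torch

def time_backward_ms(step_fn, warmup=10, iters=50):
    # Warm-up iterations exclude allocator growth and kernel autotuning effects.
    for _ in range(warmup):
        step_fn().backward()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(iters):
        loss = step_fn()
        torch.cuda.synchronize()  # drain forward kernels before timing backward
        start.record()
        loss.backward()
        end.record()
        torch.cuda.synchronize()  # wait until the backward kernels finish
        total_ms += start.elapsed_time(end)
    return total_ms / iters  # average backward time in milliseconds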

Key Observations

  1. With the return statement:
  • The cache-initialization branch runs only once (on the first forward pass).
  • Backward pass time is normal.
  2. Without the return statement:
  • Cache initialization still works correctly.
  • The backward pass slows down by ~30 ms, despite identical forward logic.
  3. Workaround:
  • Rewriting with if/else (instead of the early return) restores normal backward performance; see the sketch after this list.
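
For concreteness, the if/else rewrite from observation 3 looks roughly like this (a sketch, with the same elided arguments as the snippet above):

def forward(self, batch_extrinsic):
    if self._cache_bev_to_camera_projection is None:  # first pass: compute and cache
        with torch.no_grad():
            bev_to_camera_projection, bev_mask = self.project_lidar_to_camera(...)
            self._cache_bev_to_camera_projection = bev_to_camera_projection[0:1]
            self._cache_bev_mask = bev_mask[0:1]
        result = (bev_to_camera_projection, bev_mask)
    else:  # later passes: repeat the cached slice across the batch
        result = (self._cache_bev_to_camera_projection.repeat(...),
                  self._cache_bev_mask.repeat(...))
    return result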

Hypotheses & Questions

We suspect this relates to PyTorch's computation graph construction, but need expert input:

  1. Gradient Tracking:
  • Despite torch.no_grad(), could returning the uncached tensors create unintended gradient dependencies? (See the check sketched after this list.)
  • Does repeat() on the cached tensors retain a graph connection to the initial computation?
  2. Graph Optimization:
  • Does the early return help PyTorch prune unused graph branches during backward?
  • Why does the if/else structure resolve the slowdown?
  3. Autograd Behavior:
  • How does conditional branching affect graph retention/rebuilding overhead?
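
To probe hypothesis 1 directly, the cached tensors can be inspected after the first forward pass (a sketch; module is a hypothetical handle to the model instance, and the repeat counts are made up):

# Tensors created under torch.no_grad() should carry no autograd history.
proj = module._cache_bev_to_camera_projection
mask = module._cache_bev_mask
print(proj.requires_grad, proj.grad_fn)  # expected: False None
print(mask.requires_grad, mask.grad_fn)  # expected: False None

# repeat() on a tensor with requires_grad=False yields a tensor with no
# grad_fn, so the cached branch should not stay attached to the initial
# computation.
repeated = proj.repeat(4, 1, 1)
print(repeated.requires_grad, repeated.grad_fn)  # expected: False None

# Defensive variant: detach explicitly when populating the cache.
module._cache_bev_to_camera_projection = proj.detach()
module._cache_bev_mask = mask.detach()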

Any insights into why removing a logically redundant return statement affects backward performance would be greatly appreciated!

Thank you!

Could you post a minimal and executable code snippet reproducing this issue, please?