Hello community members,
We recently encountered a puzzling performance issue during training and would appreciate your insights. The core problem: removing a `return` statement from our module's `forward()` method leads to a 30+ ms slowdown in backward propagation, despite no logical change to code execution. Here's the context:
```python
def forward(self, batch_extrinsic):
    if self._cache_bev_to_camera_projection is None:  # Cache-initialization branch
        with torch.no_grad():
            # ~20 tensor operations for projection/mask generation
            bev_to_camera_projection, bev_mask = self.project_lidar_to_camera(...)
        # Cache results
        self._cache_bev_to_camera_projection = bev_to_camera_projection[0:1]
        self._cache_bev_mask = bev_mask[0:1]
        return bev_to_camera_projection, bev_mask  # Removing this line slows backward
    # Cached branch
    return self._cache_bev_to_camera_projection.repeat(...), self._cache_bev_mask.repeat(...)
```
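In case a runnable toy helps others reproduce or time this, here is a minimal, self-contained sketch of the same pattern (all names, shapes, and the dummy projection math are placeholders rather than our real code, and it may well not reproduce the 30 ms gap on its own):

```python
import time
import torch
import torch.nn as nn

class ToyCachedProjection(nn.Module):
    """Toy stand-in for the module above: computes a no-grad projection once,
    caches a slice of it, and early-returns the fresh tensors on the first call."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)  # gives backward something to do
        self._cache_proj = None
        self._cache_mask = None

    def forward(self, batch_extrinsic):
        b = batch_extrinsic.shape[0]
        feat = self.linear(batch_extrinsic.reshape(b, -1))
        if self._cache_proj is None:  # cache-initialization branch
            with torch.no_grad():
                # stand-in for the ~20 projection/mask operations
                proj = batch_extrinsic @ batch_extrinsic.transpose(-1, -2)
                mask = proj > 0
            self._cache_proj = proj[0:1]
            self._cache_mask = mask[0:1]
            return feat, proj, mask  # early return under test
        # cached branch
        return feat, self._cache_proj.repeat(b, 1, 1), self._cache_mask.repeat(b, 1, 1)

model = ToyCachedProjection()
x = torch.randn(8, 4, 4)
for step in range(3):
    feat, proj, mask = model(x)
    loss = feat.sum()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # only matters if model/data live on the GPU
    t0 = time.perf_counter()
    loss.backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"step {step}: backward took {(time.perf_counter() - t0) * 1e3:.2f} ms")
```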
Key Observations
- With the `return` statement:
  - The cache-initialization branch runs only once (first forward pass).
  - Backward pass time is normal.
- Without the `return` statement:
  - Cache initialization still works correctly.
  - The backward pass slows down by ~30 ms, despite identical forward logic.
- Workaround:
  - Rewriting with `if`/`else` (instead of the early `return`) restores normal backward performance; see the sketch after this list.
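For reference, a sketch of one way to express the if/else workaround (the exact rewrite may differ, and the elided arguments from the original snippet are kept as `...`, so this mirrors the code above rather than being runnable as-is):

```python
def forward(self, batch_extrinsic):
    if self._cache_bev_to_camera_projection is None:  # Cache-initialization branch
        with torch.no_grad():
            bev_to_camera_projection, bev_mask = self.project_lidar_to_camera(...)
        # Cache results
        self._cache_bev_to_camera_projection = bev_to_camera_projection[0:1]
        self._cache_bev_mask = bev_mask[0:1]
        out_projection, out_mask = bev_to_camera_projection, bev_mask
    else:  # Cached branch
        out_projection = self._cache_bev_to_camera_projection.repeat(...)
        out_mask = self._cache_bev_mask.repeat(...)
    return out_projection, out_mask
```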
Hypotheses & Questions
We suspect this relates to PyTorch's computation graph construction, but need expert input:
- Gradient Tracking:
  - Despite `torch.no_grad()`, could returning the uncached tensors create unintended gradient dependencies?
  - Does `repeat()` on the cached tensors retain graph connections to the initial computation? (See the diagnostic sketch after this list.)
- Graph Optimization:
  - Does the early `return` help PyTorch prune unused graph branches during backward?
  - Why does the `if`/`else` structure resolve the slowdown?
- Autograd Behavior:
  - How does conditional branching affect graph retention/rebuilding overhead?
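To help probe the gradient-tracking questions, here is a rough diagnostic sketch (it assumes a `model` instance and a `batch_extrinsic` tensor like those in the first snippet; `debug_check` is just an illustrative helper):

```python
import torch

def debug_check(tag, t):
    # Tensors created under torch.no_grad() should report requires_grad=False
    # and grad_fn=None; anything else means the tensor is still attached to
    # the autograd graph.
    print(f"{tag}: requires_grad={t.requires_grad}, grad_fn={t.grad_fn}")

proj, mask = model(batch_extrinsic)      # first call: cache-initialization branch
debug_check("returned proj", proj)
debug_check("cached proj", model._cache_bev_to_camera_projection)

proj2, mask2 = model(batch_extrinsic)    # second call: cached branch
debug_check("repeated proj", proj2)

# Belt and braces: storing the cache with .detach(), e.g.
# self._cache_bev_to_camera_projection = bev_to_camera_projection[0:1].detach(),
# guarantees no link back to the initial computation.
```

If everything reports `requires_grad=False` and `grad_fn=None`, the extra 30 ms is probably not coming from additional graph nodes, and profiling both variants (e.g. with `torch.profiler`) might show where the backward time actually goes.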
Any insights into why removing a logically redundant `return` impacts backward performance would be greatly appreciated!
Thank you!