Weird bug: how do unsqueeze and flatten affect gradients?

Problem Description:
I have a tensor named returns with shape [bsz] and another tensor v_new with shape [bsz, 1]. I am calculating the value function loss (vf_loss) in two different ways, and I’m observing different convergence behaviors for each method.

Here’s the relevant part of my code:

# Tensor shapes: returns [bsz], v_new [bsz,1]

# Method 1
aa = (returns.unsqueeze(-1) - v_new)  # After unsqueeze, aa has shape [bsz, 1]
vf_loss_aa = aa.pow(2).mean()

# Method 2
bb = (returns - v_new.flatten())  # v_new flattened to [bsz], bb has shape [bsz]
vf_loss_bb = bb.pow(2).mean()
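
For reference, here is a self-contained sanity check with synthetic data (bsz = 4 and the random values are placeholders, not my real inputs) confirming that the two losses agree numerically:

import torch

bsz = 4
returns = torch.randn(bsz)          # shape [bsz]
v_new = torch.randn(bsz, 1)         # shape [bsz, 1]

aa = returns.unsqueeze(-1) - v_new  # shape [bsz, 1]
bb = returns - v_new.flatten()      # shape [bsz]

# Same values, different shapes -> the scalar losses agree.
assert torch.allclose(aa.flatten(), bb)
assert torch.allclose(aa.pow(2).mean(), bb.pow(2).mean())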

Observations:

  • When using vf_loss_aa (computed from aa), the model does not converge.
  • However, with vf_loss_bb (computed from bb), the model converges normally.

Points to Consider:

  • Logically, aa and bb should represent the same values, albeit with different shapes. aa is 2D ([bsz, 1]), and bb is 1D ([bsz]).
  • The .pow(2).mean() operation should yield the same result for both aa and bb if the values are the same.
  • I’ve checked that aa.flatten() and bb are equivalent.

Questions:

  1. Why would there be a difference in convergence behavior between these two methods, given that the operations and the resulting values should theoretically be the same?
  2. Could this discrepancy be related to how PyTorch handles gradients for tensors of different shapes?
  3. Are there any known issues or subtleties in PyTorch related to this kind of situation?

Any insights, explanations, or suggestions for further debugging would be greatly appreciated. I’m curious to understand the underlying reason for this difference in behavior.

Thank you in advance for your help!

This is my bad: I accidentally changed the shape of a variable that subsequent functions assumed to have a different shape, which led to this bug.
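
The actual slip was in a different variable, but as an illustration of how this class of bug stays silent: if the shapes in my snippet had not been matched at all, broadcasting would quietly turn the subtraction into an all-pairs difference matrix instead of raising an error (bsz = 4 here is arbitrary):

import torch

bsz = 4
returns = torch.randn(bsz)     # [bsz]
v_new = torch.randn(bsz, 1)    # [bsz, 1]

# [bsz] broadcast against [bsz, 1] gives [bsz, bsz], not [bsz].
cc = returns - v_new
print(cc.shape)                # torch.Size([4, 4])

# Still reduces to a valid scalar, so training runs -- on the wrong loss.
wrong_loss = cc.pow(2).mean()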

For anyone who hits similar issues: this must be a bug in your own code, so look closely. There shouldn't be any difference in gradients when using PyTorch's reshape, view, flatten, squeeze, etc. functions.
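
A quick sketch (again with synthetic data) showing that these reshaping ops are gradient-transparent, i.e. both methods from my question produce identical gradients:

import torch

bsz = 4
returns = torch.randn(bsz)
v_new_1 = torch.randn(bsz, 1, requires_grad=True)
v_new_2 = v_new_1.detach().clone().requires_grad_(True)

# Method 1: match shapes via unsqueeze.
(returns.unsqueeze(-1) - v_new_1).pow(2).mean().backward()

# Method 2: match shapes via flatten.
(returns - v_new_2.flatten()).pow(2).mean().backward()

# Both paths yield the same gradient w.r.t. the [bsz, 1] parameter.
assert torch.allclose(v_new_1.grad, v_new_2.grad)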