Inconsistent results when printing variables

Hi,

I ran into a really peculiar situation here.
I have a pretrained network (weights frozen during training) that is supposed to take a plane sweep volume (essentially a stack of warped images) as input and produce intermediate results for other modules in the pipeline.
My plan is to run this network twice in a training step.
However, for some unknown reason, the second pass of the network produces an all-zero output.
Here is a code snippet of what happened.

# Initialization
self.net = Model()
load_ckpts(self.net, path)
self.net.eval()

...
# Training step
with torch.no_grad():
    result1 = self.gen_result(self.net, input_imgs1, input_exts, input_ints, depths)
    result2 = self.gen_result(self.net, input_imgs2, input_exts, input_ints, depths)  # result2 is all zero for some reason

Some scenarios:

  1. Run only result2
    If I comment out the first line, which generates result1, then result2 contains nonzero values as expected.

  2. Run both result1 and result2
    If I add a print between the two lines and print out input_imgs2, then result2 is nonzero as expected.
    However, if I move the print to after result2 is computed, then result2 becomes all zeros.

I have never run into similar issues before. It does not make sense to me how printing out the variables would ever change the results.

I tested line-by-line in the self.gen_result function and found out that this line could cause the issue:

trnfs = torch.matmul(src_exts, torch.inverse(tgt_exts))

If I print input_imgs2 or src_exts before this line, it works as expected.
Without the print, the output is all zeros.

Edit:
Environment: Windows 10
PyTorch 1.7.1 (also tried 1.7.0)

This sounds like a synchronization issue.
Could you post an executable code snippet to reproduce this issue?

Thanks for the reply.
I tried to isolate the network code from the training pipeline to create a code snippet.
It seems like, without the training pipeline, it performs normally.
One thing I did not mention is that I was using PyTorch-Lightning as the base structure.
Could that be a possible cause?
I am running on a single GPU, so I am not sure what you meant by a synchronization issue.
Or did you mean the synchronization between CPU and GPU?

Edit:
I just found out that I was using the CPU for the test. If I switch to the GPU, the error appears again.

Further debugging points to the operations applying torch.repeat to the input tensor.
I found that if I don't use torch.repeat, the function works properly.
I was running something like this:

psv_src = input_exts.reshape([-1, 4, 4]) # input_exts is [batch, #views, 4, 4]
psv_tgt = input_exts[:, 0:1].repeat([1, views, 1, 1]).reshape([-1, 4, 4])

trnfs = torch.matmul(psv_src, torch.inverse(psv_tgt))

So it looks like the repeat operation breaks something?
This only happens when the function is called a second time; the first call works as expected.
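
For reference, the isolated pattern is roughly the following (shapes made up here just for illustration); on its own it runs fine and only misbehaves inside the full pipeline:

import torch

device = 'cuda'  # the issue only shows up on the GPU
views = 2

# input_exts is [batch, #views, 4, 4] in the pipeline; random values stand in for the real extrinsics
input_exts = torch.randn(1, views, 4, 4, device=device)

psv_src = input_exts.reshape([-1, 4, 4])
psv_tgt = input_exts[:, 0:1].repeat([1, views, 1, 1]).reshape([-1, 4, 4])

trnfs = torch.matmul(psv_src, torch.inverse(psv_tgt))
print(trnfs.abs().max())  # nonzero here, but all zeros on the pipeline's second call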

There could be an issue in the repeat function, and an executable code snippet or any more pointers on how to reproduce it would be very helpful. :slight_smile:

Hi,
Sorry it took a while to clean up other parts.
Please download the content here and run reproduce_bug.py.
Uncomment/comment line 104 to see the bug.
Basically, the output of self.create_psv changes depending on whether the print is there or not.
Ideally, the output artifacts in the psv folder should be images with nonzero values.
https://drive.google.com/drive/folders/1dvLllgq01oLWqfLZHZgrYuicQy04-88z?usp=sharing

@ptrblck
Any update on this?
I also tested the script on another Linux machine and the result is the same.
When the GPU is enabled, it produces this weird behavior.
And interestingly, it only happens when the network is involved.
I can run the code without the network (i.e. removing lines 134~144) and the result is fine on the GPU.
I also tried creating a new tensor instead of using the repeat function and it had the same issue.

Additionally, I received an error on the Linux machine when trying to save the images.
It says “ValueError: ndarray is not C-contiguous”. However, this error does not appear on the Windows machine.
Could the contiguity of the tensors (e.g. coming from the reshape) be part of the cause?
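
For context, here is the kind of thing I have in mind (a generic sketch, not the exact visualization code):

import torch

x = torch.randn(4, 180, 320)
tmp = x.permute(1, 2, 0)              # channel-last view for saving; no longer C-contiguous
print(tmp.is_contiguous())            # False
arr = tmp.contiguous().cpu().numpy()  # .contiguous() (or np.ascontiguousarray) avoids the imsave error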

Thanks.

No updates so far. I'll try to free up some time later today to take a look at all the script files, as it doesn't seem trivial to reproduce the issue given the files.
Thanks btw. for creating them! :wink:


Could you post the shapes or code to create all needed npy files, please?

Here are the shapes

tgt_bg # (#channels, height, width)
src_bg # (#views, #channels, height, width)
tgt_rgb # (#frames, #channels, height, width) I think this is not used in the code
src_rgb # (#views, #frames, #channels, height, width) this is the main input
poses_bounds # (1, #cameras, 17)

So #views is a subset of the cameras and #cameras is the total camera count.
#frames is only 1 for now.

For poses_bounds, it is a bit more complicated.
The 17 values are a flattened 3x5 array plus 2 additional parameters denoting the near and far planes of the camera.
What I do is acquire the 3x5 array with poses_bounds[..., :-2].reshape([-1, 3, 5]).
The 3x5 array is the camera extrinsic matrix [R|t] concatenated with np.array([height, width, focal_length]) as its fifth column.
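
A rough sketch of how I unpack it (variable names here are just for illustration):

import numpy as np

poses_bounds = np.load('poses_bounds.npy')          # (1, #cameras, 17)

bounds = poses_bounds[..., -2:]                      # near/far planes, (1, #cameras, 2)
poses = poses_bounds[..., :-2].reshape([-1, 3, 5])   # (#cameras, 3, 5)

Rt = poses[..., :4]                                  # camera extrinsics [R|t], (#cameras, 3, 4)
hwf = poses[..., 4]                                  # [height, width, focal_length] per camera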

I am not sure if this is what you are looking for?
I pulled the image npy files directly from a larger dataset, so it would be a bit more complex to look into the processing code for that.

Another note: my labmate and I discovered that we could put torch.cuda.synchronize() right before the fg_mpi = self.gen_mpi(self.mpi_net, src_imgs[:, :, 0], src_exts, src_ints, depths) call and it solves the issue.
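
Concretely, the workaround looks like this (the same call as above, just with an explicit sync in front):

torch.cuda.synchronize()  # wait for all previously queued GPU work to finish before this call
fg_mpi = self.gen_mpi(self.mpi_net, src_imgs[:, :, 0], src_exts, src_ints, depths)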
Does this sound like a possible bug in GPU synchronization?
Thanks.

Thanks for the update. I’ll try to work on the reproduction.

Could be, but it’s too early to tell without proper debugging.

In the meantime, could you clone the numpy arrays to check if you are hitting this issue?

Doesn't seem like it.
I didn't use a DataLoader in the reproducible code.
I also tried adding .copy() to each numpy array (e.g. right after np.load(npy)), but it doesn't change anything.
Or did I do it wrong?

Based on the posted shapes, I’m initializing the data via:

    channels = 3
    height, width = 224, 224
    views = 2
    cameras = 4
    frames = 1

    #loaded = torch.FloatTensor(np.load('poses_bounds.npy'))
    poses_bounds = torch.randn(1, cameras, 17)#[proc_poses_bounds(x) for x in loaded]
    K = torch.stack([x[0] for x in poses_bounds])
    w2c = torch.stack([x[1] for x in poses_bounds])
    bds = torch.stack([x[2] for x in poses_bounds])

    #tgt_rgb = torch.FloatTensor(np.load('tgt_rgb.npy'))
    tgt_rgb = torch.randn(channels, height, width)
    src_rgb = torch.randn(views, channels, height, width) #torch.FloatTensor(np.load('src_rgb.npy'))
    tgt_bg = torch.randn(frames, channels, height, width) #torch.FloatTensor(np.load('tgt_bg.npy'))
    src_bg = torch.randn(views, frames, channels, height, width) #torch.FloatTensor(np.load('src_bg.npy'))

which raises:

Traceback (most recent call last):
  File "reproduce_bug.py", line 190, in <module>
    net(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 881, in _call_impl
    result = self.forward(*input, **kwargs)
  File "reproduce_bug.py", line 74, in forward
    bg_mpi = self.gen_mpi(self.mpi_net, src_bg, src_exts, src_ints, depths)
  File "reproduce_bug.py", line 114, in gen_mpi
    b, v, c, h, w = src_imgs.shape
ValueError: too many values to unpack (expected 5)

I have thus removed the unsqueeze from src_bg, but get the next error:

Traceback (most recent call last):
  File "reproduce_bug.py", line 191, in <module>
    net(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 881, in _call_impl
    result = self.forward(*input, **kwargs)
  File "reproduce_bug.py", line 74, in forward
    bg_mpi = self.gen_mpi(self.mpi_net, src_bg, src_exts, src_ints, depths)
  File "reproduce_bug.py", line 119, in gen_mpi
    psv_src = src_exts.reshape([-1, 4, 4])
RuntimeError: shape '[-1, 4, 4]' is invalid for input of size 2

Could you check the shapes again, please?

Maybe try this? And directly use sample as the input to the network.

channels = 3
height, width = 224, 224
views = 2
cameras = 4
frames = 1

tgt_rgb = torch.randn(1, frames, channels, height, width)
src_rgb = torch.randn(1, views, frames, channels, height, width)
tgt_bg = torch.randn(1, channels, height, width)
src_bg = torch.randn(1, views, channels, height, width)

src_w2c = torch.randn(1, views, 4, 4)
src_K = torch.randn(1, views, 3, 3)
tgt_w2c = torch.randn(1, 4, 4)
tgt_K = torch.randn(1, 3, 3)
bd = torch.FloatTensor([1, 100])

sample = dict()
# Pack data
sample['tgt_rgb'] = tgt_rgb.to(device)
sample['src_rgb'] = src_rgb.to(device)

sample['src_bg'] = src_bg.to(device)
sample['tgt_bg'] = tgt_bg.to(device)

sample['src_w2c'] = src_w2c.to(device)
sample['src_K'] = src_K.to(device)
sample['tgt_w2c'] = tgt_w2c.to(device)
sample['tgt_K'] = tgt_K.to(device)

sample['bd'] = torch.FloatTensor([
    torch.min(bd) * .9,
    torch.max(bd) * 2.,
]).unsqueeze(0).to(device)

However, one caveat is that randomly initializing everything could produce weird outcomes, since the calculations are based on projective geometry.

I’m able to run the code until the numpy error is raised:

tensor(14018246., device='cuda:0')
torch.Size([2, 3, 32, 180, 320])
max: tensor(1.9529, device='cuda:0')
max: tensor(0.9529, device='cuda:0')
tensor(13985083., device='cuda:0')
torch.Size([2, 3, 32, 180, 320])
max: tensor(1.9490, device='cuda:0')
max: tensor(0.9490, device='cuda:0')
input to fg_mpi:  tensor(113580.9062, device='cuda:0')
Traceback (most recent call last):
  File "reproduce_bug.py", line 194, in <module>
    net(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 881, in _call_impl
    result = self.forward(*input, **kwargs)
  File "reproduce_bug.py", line 78, in forward
    save_rgba(fg_mpi.squeeze(), 'rgba')
  File "/workspace/src/visualization.py", line 19, in save_rgba
    plt.imsave(os.path.join(folder, f'rgba_{i:03d}.png'), tmp.cpu().numpy())
  File "/opt/conda/lib/python3.8/site-packages/matplotlib/pyplot.py", line 2251, in imsave
    return matplotlib.image.imsave(fname, arr, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/matplotlib/image.py", line 1581, in imsave
    image = PIL.Image.frombuffer(
  File "/opt/conda/lib/python3.8/site-packages/PIL/Image.py", line 2637, in frombuffer
    im = im._new(core.map_buffer(data, size, decoder_name, None, 0, args))
ValueError: ndarray is not C-contiguous

Could you let me know which tensors are supposed to show the potential race condition for further debugging?

I encountered a similar error on the Linux machine, but not on the Windows machine.
Could you simply comment out line 78: save_rgba(fg_mpi.squeeze(), 'rgba')?
It shouldn't have much to do with the main issue at hand.
As for your question: in the line trnfs = torch.matmul(src_exts, torch.inverse(tgt_exts)), trnfs and the subsequent psv tensors are the ones in question.
The main issue is that the psv tensor becomes zero when it shouldn't.
One indicator is the second set of printed max values of the psv tensor.
In your output,

max: tensor(1.9490, device='cuda:0')
max: tensor(0.9490, device='cuda:0')

these are the desired values.
When the error appears, they become something like this:

max: tensor(0., device='cuda:0')
max: tensor(-1., device='cuda:0')

Maybe it has something to do with the CUDA version?
My environment is CUDA 10.1 on both the Linux and Windows machines mentioned.
I just ran a quick test on CUDA 11.2 and it seems to perform okay without torch.cuda.synchronize().
Could you confirm you are not running 10.1?
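
For reference, I'm reading the versions on each machine with the standard attributes:

import torch

print(torch.__version__)          # installed PyTorch version
print(torch.version.cuda)         # CUDA version the binary was built with
print(torch.cuda.is_available())  # sanity check that the GPU is visible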
Thanks.
If this is indeed the issue, then I think I could close the thread.

Thanks again for the code snippet as well as the great debugging!
I was able to narrow it down to a sync issue using this “minimal” code snippet:

import torch

x = torch.randn(1, 2, 4, 4, device='cuda')
v = 2
m = torch.randn(1024, 1024, device='cuda')

res = []
for _ in range(10):
    psv_tgt = x[:, 0:1].repeat([1, v, 1, 1]).reshape([-1, 4, 4])
    res.append(torch.inverse(psv_tgt))
    # pseudo workload to keep the GPU queue busy
    for _ in range(5):
        torch.matmul(m, m)

res = torch.stack(res)
print(res)

The pseudo workload is necessary to show the sync issue, otherwise you’ll get valid outputs.
The issue is created by a wrong sync pattern in apply_batched_inverse_lib, which was already fixed in this PR.
It’s good to know that the code works fine with 11.2, but I assume you’ve built PyTorch from source or installed the 1.8rc for it (which would thus already include the fix).
I’ve also verified that your code works fine using the nightly binary with CUDA10.1, so I’m quite sure you were also hitting the aforementioned issue.
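
On the affected 1.7.x binaries, keeping the explicit torch.cuda.synchronize() you found should work as a temporary workaround in the same pattern; a quick sketch:

import torch

x = torch.randn(1, 2, 4, 4, device='cuda')
v = 2
m = torch.randn(1024, 1024, device='cuda')

res = []
for _ in range(10):
    psv_tgt = x[:, 0:1].repeat([1, v, 1, 1]).reshape([-1, 4, 4])
    torch.cuda.synchronize()  # flush the queued work before the batched inverse
    res.append(torch.inverse(psv_tgt))
    for _ in range(5):
        torch.matmul(m, m)    # pseudo workload that otherwise exposes the issue

print(torch.stack(res))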

Thanks for the help and your time! @ptrblck
Yes, I did build PyTorch from source.
I am glad this is solved. I will close this topic now.