Discrepancy in sum of a matrix on MPS and CPU (or its matrix view and vector view)

I ran into a peculiar situation in which I want to sum the elements of a matrix W that resides on MPS using W.sum(). I will describe the problem first, and then try to describe, at the end of this post, the exact operations that produced this matrix.

The problem is that W.sum() returns a value that is clearly wrong (clearly so in my case, at least). For example, we have the following:

  1. W.sum() == W.view(-1).sum() returns False.
  2. W.view(-1).view(*W.shape).sum() == W.view(-1).sum() returns False.
  3. W.sum() == W.cpu().sum() returns False.
  4. W.cpu().to('mps').sum() == W.cpu().sum() returns False!

In other words, we have:

  • W.sum() == W.view(-1).view(*W.shape).sum() == W.cpu().to('mps').sum() == wrong_value, and
  • W.view(-1).sum() == W.cpu().sum() == W.sum(1).sum() == W.sum(0).sum() == true_value.

This behavior extends to many other functions that reduce a matrix to a single value (e.g. norm() and count_nonzero()), so something is clearly off with the matrix view of W on the MPS device.
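A minimal sketch of the checks I ran (W here is just a random MPS tensor as a stand-in, so it will not reproduce the wrong value; the real W comes out of my training run):

import torch

W = torch.randn(1024, 1024, device="mps")  # stand-in; my real W comes from training

print(W.sum() == W.view(-1).sum())                          # False in my case
print(W.view(-1).view(*W.shape).sum() == W.view(-1).sum())  # False in my case
print(W.sum() == W.cpu().sum())                             # False in my case
print(W.cpu().to("mps").sum() == W.cpu().sum())             # False in my case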

The code that produced such a matrix in my case can be summarized as follows: for each param in a model, create W, a matrix view of the param; create W_new, a new matrix like W; run some operations on W_new; take param_new, a param-shaped view of W_new; and copy it back onto the param with param.copy_(param_new). This all runs under a torch.no_grad() context; a rough sketch follows.
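(The Linear model and some_update below are placeholders, not my actual code; they are only meant to show the shape of the loop.)

import torch

def some_update(W_new):
    # placeholder for the actual operations I run on W_new during training
    return W_new * (W_new.abs() > 1e-3)

model = torch.nn.Linear(64, 64).to("mps")  # placeholder model

with torch.no_grad():
    for param in model.parameters():
        W = param.view(param.shape[0], -1)   # matrix view of the param
        W_new = torch.empty_like(W)          # new matrix like W
        W_new.copy_(W)
        W_new = some_update(W_new)           # some operations on W_new
        param_new = W_new.view(param.shape)  # param-shaped view of W_new
        param.copy_(param_new)               # copy back onto the param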

I assume you are comparing floating point values directly, which might fail due to their limited precision and a potentially different order of operations.
Here is a small example using sum and forcing a change in the order of operations:

import torch

x = torch.randn(100, 100)
s1 = x.sum()         # reduce over all elements in one call
s2 = x.sum(0).sum()  # reduce dim 0 first, then sum the partial results
print((s1 - s2).abs().max())
# tensor(1.5259e-05)
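For differences in that range, a comparison with a tolerance instead of == would pass:

# allclose with tolerances sized for the expected rounding error returns True here
print(torch.allclose(s1, s2, rtol=1e-5, atol=1e-4))
# True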

How large are the errors you are seeing in your use case?

Sorry for the late reply. The summation itself is not the issue here. Honestly, I was counting non-zeros, so floating-point precision might not explain the issue (the summation example above still holds, though). Also, floating point doesn't explain the slightly different number of non-zeros I get when I simply move my model to CPU (or does it? I assume the CPU case produces the true value). I will try to come up with a reproducible example, but it's very tricky since I do not know exactly what led to this behavior during training. FYI, I have a MacBook Pro (16-inch, 2021) with an Apple M1 Pro and PyTorch 2.0 installed.
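For what it's worth, the kind of check where I see the mismatch looks roughly like this (the Linear model here is only a placeholder; the real discrepancy shows up with my trained model):

import torch

model = torch.nn.Linear(64, 64).to("mps")  # placeholder; my real model comes from training

for name, param in model.named_parameters():
    nz_mps = param.count_nonzero().item()
    nz_cpu = param.cpu().count_nonzero().item()
    if nz_mps != nz_cpu:
        print(f"{name}: mps={nz_mps}, cpu={nz_cpu}")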