I ran into a peculiar situation while trying to sum the elements of a matrix W that resides on the MPS device using W.sum(). I will describe the problem first, and then, at the end of this post, the exact operations that produced this matrix.
The problem is that W.sum() returns a value that is clearly wrong (clearly, in my case). For example, we have the following:
W.sum() == W.view(-1).sum() returns False.
W.view(-1).view(*W.shape).sum() == W.view(-1).sum() returns False.
W.sum() == W.cpu().sum() returns False.
W.cpu().to('mps').sum() == W.cpu().sum() returns False!
In other words, we have:
W.sum() == W.view(-1).view(*W.shape).sum() == W.cpu().to('mps').sum() == wrong_value, and
W.view(-1).sum() == W.cpu().sum() == W.sum(1).sum() == W.sum(0).sum() == true_value.
This behavior extends to many other functions that reduce a matrix to a single value (e.g. count_nonzero()), so something is clearly off with the matrix view of W on the MPS device.
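For anyone who wants to run the same checks on their own tensor, here is a minimal sketch that collects the comparisons above in one place. The helper name compare_reductions and the small test matrix are my own illustration, not the code that produced the problematic W. The demo matrix holds small integer values so that every summation order gives an exact float32 result, and the script falls back to the CPU when MPS is not available.

```python
import torch


def compare_reductions(W: torch.Tensor) -> dict:
    """Compute the total of W along several routes and return them as Python floats.

    Hypothetical helper for illustration; it mirrors the checks listed above.
    W is assumed to be a contiguous 2-D tensor so that view(-1) is valid.
    """
    return {
        "W.sum()": W.sum().item(),
        "W.view(-1).sum()": W.view(-1).sum().item(),
        "W.view(-1).view(*W.shape).sum()": W.view(-1).view(*W.shape).sum().item(),
        "W.cpu().sum()": W.cpu().sum().item(),
        "W.cpu().to(W.device).sum()": W.cpu().to(W.device).sum().item(),
        "W.sum(0).sum()": W.sum(0).sum().item(),
        "W.sum(1).sum()": W.sum(1).sum().item(),
    }


# Use MPS when present, otherwise fall back to CPU so the script still runs.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Small integer-valued matrix: the sums are exact in float32, so == is a fair test.
W = torch.arange(12, dtype=torch.float32, device=device).reshape(3, 4)

results = compare_reductions(W)
for name, value in results.items():
    print(f"{name:35s} {value}")

# On a healthy tensor every route should agree.
print("all routes agree:", len(set(results.values())) == 1)
```

A freshly created tensor like this one should pass on any device. The point of the helper is to call it on the W that comes out of the procedure described below and see which routes disagree.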
The code that produced such a matrix in my case can be summarized as follows: for each param in a model, create W, a matrix view of the param; create W_new, a new matrix like W; run some operations on W_new; return param_new, a param view of W_new; and copy it onto param with param.copy_(param_new). This all runs under