I ran into a peculiar situation in which I want to sum the elements of a matrix W that resides on the MPS device using `W.sum()`. I will describe the problem first, and at the end of this post I will describe the operations that produced this matrix.
The problem is that `W.sum()` returns a value that is clearly wrong (clearly in my case). For example:

- `W.sum() == W.view(-1).sum()` returns False.
- `W.view(-1).view(*W.shape).sum() == W.view(-1).sum()` returns False.
- `W.sum() == W.cpu().sum()` returns False.
- `W.cpu().to('mps').sum() == W.cpu().sum()` returns False!
In other words, we have:

- `W.sum() == W.view(-1).view(*W.shape).sum() == W.cpu().to('mps').sum() == wrong_value`, and
- `W.view(-1).sum() == W.cpu().sum() == W.sum(1).sum() == W.sum(0).sum() == true_value`.
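The comparisons above can be collected into a small consistency check. This is a sketch, not my exact code: the helper name `check_reductions` is made up, and the example tensor lives on the CPU, where all the reduction paths agree (on MPS, per the above, they diverge).

```python
import torch

def check_reductions(W: torch.Tensor) -> dict:
    # Several reduction paths that should all produce the same value.
    # The reference is computed on CPU in double precision.
    return {
        "sum": W.sum().item(),
        "flat_sum": W.view(-1).sum().item(),
        "rowcol_sum": W.sum(0).sum().item(),
        "cpu_sum": W.cpu().double().sum().item(),
    }

# Demonstrate on CPU, where everything is consistent.
W = torch.arange(12, dtype=torch.float32).view(3, 4)
results = check_reductions(W)
consistent = all(abs(v - results["cpu_sum"]) < 1e-4 for v in results.values())
print(consistent)
```

On an MPS tensor exhibiting the bug, `results["sum"]` would differ from `results["cpu_sum"]` while `results["flat_sum"]` and `results["rowcol_sum"]` would match it.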
This behavior extends to many other functions that reduce a matrix to a single value (e.g. `norm()` and `count_nonzero()`), so something is clearly off with the matrix view of W on the MPS device.
The code that produced such a matrix in my case can be summarized as follows: for each param in a model, create W, a matrix view of the param; create W_new, a new matrix like W; run some operations on W_new; create param_new, a param-shaped view of W_new; and copy it back onto the param with `param.copy_(param_new)`. This all runs under a `torch.no_grad()` context.
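A minimal sketch of that loop, with placeholders for the real work: `update_params` is a made-up name, and the `W * 2.0` step stands in for whatever operations were actually run on W_new.

```python
import torch

def update_params(model: torch.nn.Module) -> None:
    # Sketch of the loop described above; runs under no_grad as in the post.
    with torch.no_grad():
        for param in model.parameters():
            W = param.view(param.shape[0], -1)   # matrix view of the param
            W_new = torch.empty_like(W)          # new matrix like W
            W_new.copy_(W * 2.0)                 # placeholder for the real ops
            param_new = W_new.view(param.shape)  # param-shaped view of W_new
            param.copy_(param_new)               # copy back onto the param

model = torch.nn.Linear(4, 3)
before = model.weight.detach().clone()
update_params(model)
changed = torch.allclose(model.weight, before * 2.0)
print(changed)
```

On CPU this behaves as expected; the question is why the resulting params on MPS end up with inconsistent full-matrix reductions.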