Call by object reference, tensor object

import torch

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    t[5] = 5
    return t

t = torch.ones(4, 3)
flatten(t)
print(t)

I couldn’t understand why the result is

tensor([[1., 1., 1.],
        [1., 1., 5.],
        [1., 1., 1.],
        [1., 1., 1.]])

I expected the original matrix of 1’s to stay unchanged, because inside the function t is just a copy of the reference to the object. After the first statement, t = t.reshape(1, -1), t references a new object (a different object), and the same happens on the next line. So when we change t with t[5] = 5, we should be changing a different object (since, as said, t now references another object).
So how come t changed?

A tensor is stored in memory as a single contiguous block; you can picture it as a straight line of elements. Alongside the data, every tensor carries two pieces of metadata: its size and its stride. The size gives the length of the tensor along each dimension, and the stride gives the number of elements (elements, not bytes, in PyTorch) you need to step over in that flat block to move one position along each dimension.
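To make the size/stride idea concrete, here is a small check (assuming torch is installed):

```python
import torch

# A 4x3 tensor is laid out as 12 consecutive elements in one memory block.
t = torch.ones(4, 3)

# size: length along each dimension.
# stride: how many elements to skip in the flat block to move one step
# along each dimension (PyTorch's stride() counts elements, not bytes).
print(t.shape)     # torch.Size([4, 3])
print(t.stride())  # (3, 1): next row is 3 elements away, next column is 1
```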

reshape returns a view when it can, i.e. in your case it only changes the size/stride metadata while keeping the same underlying data (no bytes in memory are copied or moved). The same is true for squeeze. So when you run t[5] = 5, you are writing into the very same underlying storage that the original t points to. For clarity, the code below, which gives the argument a different name from the outer tensor, produces the same output as yours.

import torch

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    t[5] = 5
    return t

x = torch.ones(4, 3)
flatten(x)
print(x)
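You can verify the shared storage directly, and avoid the mutation by copying the data first. A minimal sketch (the clone() call is one way to break the sharing, not the only one):

```python
import torch

t = torch.ones(4, 3)
v = t.reshape(1, -1).squeeze()

# Both tensors start at the same memory address: they share one storage.
assert t.data_ptr() == v.data_ptr()

# clone() allocates fresh memory, so writes to the copy no longer
# propagate back to the original tensor.
c = t.reshape(1, -1).squeeze().clone()
c[5] = 5
print(t[1, 2].item())  # still 1.0: the original is unaffected
```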