Any way to check if two tensors have the same base

I’m not sure I fully understood your question, but I’ll try to answer:

import torch
x = torch.randn(4, 4)
y = x.view(2,-1)
print(x.data_ptr() == y.data_ptr()) # prints True
y = x.clone().view(2,-1)
print(x.data_ptr() == y.data_ptr()) # prints False
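Since the question asks about the *base* specifically, one option is `Tensor._base`. This is an internal attribute, not documented public API, so treat it as an assumption that may change between releases. For a view it holds the tensor the view was created from, and for a non-view it is `None`. The helper `base_of` below is a hypothetical name used only for illustration:

```python
import torch

def base_of(t):
    # Hypothetical helper. Tensor._base is internal: it is the tensor that t
    # is a view of, or None when t is not a view (then t is its own base).
    return t if t._base is None else t._base

x = torch.randn(4, 4)
y = x.view(2, -1)           # view of x
w = x.clone().view(2, -1)   # view of a fresh copy, not of x

print(base_of(y) is x)  # True
print(base_of(w) is x)  # False
```

This only tracks view relationships recorded by PyTorch. It says nothing about tensors that happen to alias the same memory through other means.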

However, this doesn't work if you want to compare the underlying storage. In the example below, x and y share storage, but y starts one element into it, so the two tensors have different data pointers:

import torch
x = torch.arange(10)
y = x[1::2]
print(x.data_ptr() == y.data_ptr()) # prints False
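The mismatch comes from `data_ptr()` returning the address of the tensor's *first element*, not the start of its storage. For a view, that address is the storage start plus `storage_offset()` elements. The sketch below checks this arithmetic on the same slice, assuming the default `int64` dtype of `torch.arange` (8 bytes per element):

```python
import torch

x = torch.arange(10)  # int64 by default, 8 bytes per element
y = x[1::2]           # view starting at storage element 1, stride 2

# Byte distance between the storage start and y's first element
offset_bytes = y.storage_offset() * y.element_size()

print(y.storage_offset())                           # 1
print(y.data_ptr() - x.data_ptr())                  # 8
print(y.data_ptr() == x.data_ptr() + offset_bytes)  # True
```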

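The following snippet checks whether two tensors share storage by comparing the memory addresses of their individual elements, though I guess there might be a more efficient way of doing this. Note that two slices of the same storage with no elements in common (e.g. x[0::2] and x[1::2]) are reported as not sharing.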

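import itertools
import torch

def same_storage(x, y):
    # Address of every element. Indexing with a full tuple of ints returns a
    # 0-d view, so this also works for non-contiguous tensors such as x.t(),
    # where x.view(-1) would raise an error.
    x_ptrs = {x[i].data_ptr() for i in itertools.product(*map(range, x.shape))}
    y_ptrs = {y[i].data_ptr() for i in itertools.product(*map(range, y.shape))}
    # True if the tensors have at least one element address in common
    return not x_ptrs.isdisjoint(y_ptrs)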

x = torch.arange(10)
y = x[1::2]
print(same_storage(x, y)) # prints True
z = y.clone()
print(same_storage(x, z)) # prints False
print(same_storage(y, z)) # prints False
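A cheaper check, assuming PyTorch 2.0 or later where `Tensor.untyped_storage()` is available (older releases expose `Tensor.storage()` instead), is to compare the start address of the storage backing each tensor rather than every element. `shares_storage` is a hypothetical helper name:

```python
import torch

def shares_storage(a, b):
    # Hypothetical helper: compares the base address of the storages backing
    # a and b. Assumes PyTorch >= 2.0 for Tensor.untyped_storage().
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.arange(10)
y = x[1::2]
z = y.clone()

print(shares_storage(x, y))              # True
print(shares_storage(x, z))              # False
print(shares_storage(x[0::2], x[1::2]))  # True: no common elements, same storage
```

Unlike the element-by-element version, this runs in constant time and also catches slices that share a storage without overlapping. One caveat: empty storages may report a null data pointer, so two unrelated empty tensors could compare as equal.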